Spaces:
Runtime error
Runtime error
improve pred
Browse files- app.py +1 -1
- sim/simulator.py +9 -5
app.py
CHANGED
|
@@ -14,7 +14,7 @@ genie = GenieSimulator(
|
|
| 14 |
quantize=False,
|
| 15 |
backbone_type='stmar',
|
| 16 |
backbone_ckpt='data/mar_ckpt/langtable',
|
| 17 |
-
prompt_horizon=
|
| 18 |
action_stride=1,
|
| 19 |
domain='language_table',
|
| 20 |
)
|
|
|
|
| 14 |
quantize=False,
|
| 15 |
backbone_type='stmar',
|
| 16 |
backbone_ckpt='data/mar_ckpt/langtable',
|
| 17 |
+
prompt_horizon=2,
|
| 18 |
action_stride=1,
|
| 19 |
domain='language_table',
|
| 20 |
)
|
sim/simulator.py
CHANGED
|
@@ -248,22 +248,26 @@ class GenieSimulator(LearnedSimulator):
|
|
| 248 |
# encoding
|
| 249 |
input_latent_states = torch.cat([
|
| 250 |
self.cached_latent_frames,
|
| 251 |
-
torch.zeros_like(self.cached_latent_frames[
|
| 252 |
]).unsqueeze(0).to(torch.float32)
|
| 253 |
|
|
|
|
|
|
|
| 254 |
# dtype conversion and mask token
|
| 255 |
if self.backbone_type == "stmaskgit":
|
| 256 |
input_latent_states = input_latent_states.long()
|
| 257 |
-
input_latent_states[:,
|
| 258 |
elif self.backbone_type == "stmar":
|
| 259 |
-
input_latent_states[:,
|
| 260 |
|
| 261 |
# dynamics rollout
|
| 262 |
action = torch.from_numpy(action).to(device=self.device)
|
| 263 |
input_actions = torch.cat([ # (1, prompt_horizon + 1, action_stride * A)
|
| 264 |
self.cached_actions,
|
| 265 |
-
action.unsqueeze(0)
|
| 266 |
-
|
|
|
|
|
|
|
| 267 |
|
| 268 |
if self.measure_step_time:
|
| 269 |
start_time = time.time()
|
|
|
|
| 248 |
# encoding
|
| 249 |
input_latent_states = torch.cat([
|
| 250 |
self.cached_latent_frames,
|
| 251 |
+
torch.zeros_like(self.cached_latent_frames[[0]]),
|
| 252 |
]).unsqueeze(0).to(torch.float32)
|
| 253 |
|
| 254 |
+
input_latent_states = input_latent_states[:, :self.prompt_horizon + 1]
|
| 255 |
+
|
| 256 |
# dtype conversion and mask token
|
| 257 |
if self.backbone_type == "stmaskgit":
|
| 258 |
input_latent_states = input_latent_states.long()
|
| 259 |
+
input_latent_states[:, -1] = self.backbone.mask_token_id
|
| 260 |
elif self.backbone_type == "stmar":
|
| 261 |
+
input_latent_states[:, -1] = self.backbone.mask_token
|
| 262 |
|
| 263 |
# dynamics rollout
|
| 264 |
action = torch.from_numpy(action).to(device=self.device)
|
| 265 |
input_actions = torch.cat([ # (1, prompt_horizon + 1, action_stride * A)
|
| 266 |
self.cached_actions,
|
| 267 |
+
action.unsqueeze(0),
|
| 268 |
+
action.unsqueeze(0) # the last action is not used, but we need a_{t-1}, s_{t-1} to predict s_t
|
| 269 |
+
]).view(1, -1, action.shape[-1]).to(torch.float32) # + 1
|
| 270 |
+
input_actions = input_actions[:, :self.prompt_horizon + 1]
|
| 271 |
|
| 272 |
if self.measure_step_time:
|
| 273 |
start_time = time.time()
|