Update delta-iris/src/world_model.py
Browse files
delta-iris/src/world_model.py
CHANGED
|
@@ -70,16 +70,3 @@ class WorldModel(nn.Module):
|
|
| 70 |
"logits_rewards": logits_rewards,
|
| 71 |
"logits_ends": logits_ends
|
| 72 |
}
|
| 73 |
-
|
| 74 |
-
@torch.no_grad()
|
| 75 |
-
def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor, latent_tokens: torch.LongTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
|
| 76 |
-
assert obs.size(1) == act.size(1) + 1 == latent_tokens.size(1) + 1
|
| 77 |
-
|
| 78 |
-
x_emb = self.frame_cnn(obs)
|
| 79 |
-
act_emb = rearrange(self.act_emb(act), 'b t e -> b t 1 e')
|
| 80 |
-
q_emb = self.latents_emb(latent_tokens)
|
| 81 |
-
x_a_q = rearrange(torch.cat((x_emb[:, :-1], act_emb, q_emb), dim=2), 'b t k2 e -> b (t k2) e')
|
| 82 |
-
wm_input_sequence = torch.cat((x_a_q, x_emb[:, -1]), dim=1)
|
| 83 |
-
wm_output_sequence = self(wm_input_sequence, use_kv_cache=use_kv_cache).output_sequence
|
| 84 |
-
|
| 85 |
-
return wm_output_sequence
|
|
|
|
| 70 |
"logits_rewards": logits_rewards,
|
| 71 |
"logits_ends": logits_ends
|
| 72 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|