ShaswatRobotics commited on
Commit
3e9b1c7
·
verified ·
1 Parent(s): 55c5877

Update delta-iris/src/world_model.py

Browse files
Files changed (1) hide show
  1. delta-iris/src/world_model.py +0 -13
delta-iris/src/world_model.py CHANGED
@@ -70,16 +70,3 @@ class WorldModel(nn.Module):
70
  "logits_rewards": logits_rewards,
71
  "logits_ends": logits_ends
72
  }
73
-
74
- @torch.no_grad()
75
- def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor, latent_tokens: torch.LongTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
76
- assert obs.size(1) == act.size(1) + 1 == latent_tokens.size(1) + 1
77
-
78
- x_emb = self.frame_cnn(obs)
79
- act_emb = rearrange(self.act_emb(act), 'b t e -> b t 1 e')
80
- q_emb = self.latents_emb(latent_tokens)
81
- x_a_q = rearrange(torch.cat((x_emb[:, :-1], act_emb, q_emb), dim=2), 'b t k2 e -> b (t k2) e')
82
- wm_input_sequence = torch.cat((x_a_q, x_emb[:, -1]), dim=1)
83
- wm_output_sequence = self(wm_input_sequence, use_kv_cache=use_kv_cache).output_sequence
84
-
85
- return wm_output_sequence
 
70
  "logits_rewards": logits_rewards,
71
  "logits_ends": logits_ends
72
  }