ShaswatRobotics commited on
Commit
72553e8
·
verified ·
1 Parent(s): 3e9b1c7

Update delta-iris/src/world_model.py

Browse files
Files changed (1) hide show
  1. delta-iris/src/world_model.py +6 -0
delta-iris/src/world_model.py CHANGED
@@ -54,6 +54,12 @@ class WorldModel(nn.Module):
54
  def __repr__(self) -> str:
55
  return "world_model"
56
 
 
 
 
 
 
 
57
  def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
58
  prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
59
  num_steps = sequence.size(1)
 
54
  def __repr__(self) -> str:
55
  return "world_model"
56
 
57
+ def blocks_left_in_kv_cache(self):
58
+ return self.transformer.num_blocks_left_in_kv_cache
59
+
60
+ def reset_kv_cache(self):
61
+ transformer.reset_kv_cache(n=1)
62
+
63
  def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
64
  prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
65
  num_steps = sequence.size(1)