Update delta-iris/src/world_model.py
Browse files
delta-iris/src/world_model.py
CHANGED
|
@@ -58,7 +58,7 @@ class WorldModel(nn.Module):
|
|
| 58 |
return self.transformer.num_blocks_left_in_kv_cache
|
| 59 |
|
| 60 |
def reset_kv_cache(self):
|
| 61 |
-
transformer.reset_kv_cache(n=1)
|
| 62 |
|
| 63 |
def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
|
| 64 |
prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
|
|
|
|
| 58 |
return self.transformer.num_blocks_left_in_kv_cache
|
| 59 |
|
| 60 |
def reset_kv_cache(self):
|
| 61 |
+
self.transformer.reset_kv_cache(n=1)
|
| 62 |
|
| 63 |
def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
|
| 64 |
prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
|