Update delta-iris/src/world_model.py
Browse files
delta-iris/src/world_model.py
CHANGED
|
@@ -54,6 +54,12 @@ class WorldModel(nn.Module):
|
|
| 54 |
def __repr__(self) -> str:
|
| 55 |
return "world_model"
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
|
| 58 |
prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
|
| 59 |
num_steps = sequence.size(1)
|
|
|
|
| 54 |
def __repr__(self) -> str:
|
| 55 |
return "world_model"
|
| 56 |
|
| 57 |
+
def blocks_left_in_kv_cache(self):
|
| 58 |
+
return self.transformer.num_blocks_left_in_kv_cache
|
| 59 |
+
|
| 60 |
+
def reset_kv_cache(self):
|
| 61 |
+
transformer.reset_kv_cache(n=1)
|
| 62 |
+
|
| 63 |
def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
|
| 64 |
prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
|
| 65 |
num_steps = sequence.size(1)
|