Update iris/src/world_model.py
Browse files- iris/src/world_model.py +2 -2
iris/src/world_model.py
CHANGED
|
@@ -84,9 +84,9 @@ class WorldModel(nn.Module):
|
|
| 84 |
|
| 85 |
}
|
| 86 |
|
| 87 |
-
def generate_empty_keys_values(self):
|
| 88 |
|
| 89 |
-
values = self.world_model.transformer.generate_empty_keys_values(n=
|
| 90 |
|
| 91 |
def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 92 |
assert torch.all(ends.sum(dim=1) <= 1) # at most 1 done
|
|
|
|
| 84 |
|
| 85 |
}
|
| 86 |
|
| 87 |
+
def generate_empty_keys_values(self, n= 1):
|
| 88 |
|
| 89 |
+
values = self.world_model.transformer.generate_empty_keys_values(n=n, max_tokens= self.config["max_tokens"])
|
| 90 |
|
| 91 |
def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 92 |
assert torch.all(ends.sum(dim=1) <= 1) # at most 1 done
|