ShaswatRobotics commited on
Commit
a30bf83
·
verified ·
1 Parent(s): 29d447d

Update iris/src/world_model.py

Browse files
Files changed (1) hide show
  1. iris/src/world_model.py +2 -2
iris/src/world_model.py CHANGED
@@ -84,9 +84,9 @@ class WorldModel(nn.Module):
84
 
85
  }
86
 
87
- def generate_empty_keys_values(self):
88
 
89
- values = self.world_model.transformer.generate_empty_keys_values(n=1, max_tokens= self.config["max_tokens"])
90
 
91
  def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
92
  assert torch.all(ends.sum(dim=1) <= 1) # at most 1 done
 
84
 
85
  }
86
 
87
+ def generate_empty_keys_values(self, n= 1):
88
 
89
+ values = self.world_model.transformer.generate_empty_keys_values(n=n, max_tokens= self.config["max_tokens"])
90
 
91
  def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
92
  assert torch.all(ends.sum(dim=1) <= 1) # at most 1 done