ShaswatRobotics commited on
Commit
cb06601
·
verified ·
1 Parent(s): b7dc1ea

Update iris/src/world_model.py

Browse files
Files changed (1) hide show
  1. iris/src/world_model.py +2 -2
iris/src/world_model.py CHANGED
@@ -27,7 +27,7 @@ class WorldModel(nn.Module):
27
  self.embedder = Embedder(
28
  max_blocks=config["max_blocks"],
29
  block_masks=[act_tokens_pattern, obs_tokens_pattern],
30
- embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config["embed_dim"]), nn.Embedding(obs_vocab_size, config["embed_dim"])])
31
  )
32
 
33
  self.head_observations = Head(
@@ -36,7 +36,7 @@ class WorldModel(nn.Module):
36
  head_module=nn.Sequential(
37
  nn.Linear(config["embed_dim"], config["embed_dim"]),
38
  nn.ReLU(),
39
- nn.Linear(config["embed_dim"], obs_vocab_size)
40
  )
41
  )
42
 
 
27
  self.embedder = Embedder(
28
  max_blocks=config["max_blocks"],
29
  block_masks=[act_tokens_pattern, obs_tokens_pattern],
30
+ embedding_tables=nn.ModuleList([nn.Embedding(self.act_vocab_size, config["embed_dim"]), nn.Embedding(self.obs_vocab_size, config["embed_dim"])])
31
  )
32
 
33
  self.head_observations = Head(
 
36
  head_module=nn.Sequential(
37
  nn.Linear(config["embed_dim"], config["embed_dim"]),
38
  nn.ReLU(),
39
+ nn.Linear(config["embed_dim"], self.obs_vocab_size)
40
  )
41
  )
42