ShaswatRobotics commited on
Commit
d1d30ce
·
verified ·
1 Parent(s): c32c7e6

Update iris/src/world_model.py

Browse files
Files changed (1) hide show
  1. iris/src/world_model.py +2 -2
iris/src/world_model.py CHANGED
@@ -10,9 +10,9 @@ from .models.slicer import Embedder, Head
10
  from .models.transformer import Transformer
11
 
12
  class WorldModel(nn.Module):
13
- def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: dict) -> None:
14
  super().__init__()
15
- self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size
16
  self.config = config
17
  self.transformer = Transformer(config)
18
 
 
10
  from .models.transformer import Transformer
11
 
12
  class WorldModel(nn.Module):
13
+ def __init__(self, config: dict) -> None:
14
  super().__init__()
15
+ self.obs_vocab_size, self.act_vocab_size = config["obs_vocab_size"], config["act_vocab_size"]
16
  self.config = config
17
  self.transformer = Transformer(config)
18