Update iris/src/world_model.py
Browse files- iris/src/world_model.py +1 -1
iris/src/world_model.py
CHANGED
|
@@ -12,7 +12,7 @@ from .models.transformer import Transformer
|
|
| 12 |
class WorldModel(nn.Module):
|
| 13 |
def __init__(self, config: dict) -> None:
|
| 14 |
super().__init__()
|
| 15 |
-
self.obs_vocab_size, self.act_vocab_size = config["
|
| 16 |
self.config = config
|
| 17 |
self.transformer = Transformer(config)
|
| 18 |
|
|
|
|
| 12 |
class WorldModel(nn.Module):
|
| 13 |
def __init__(self, config: dict) -> None:
|
| 14 |
super().__init__()
|
| 15 |
+
self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
|
| 16 |
self.config = config
|
| 17 |
self.transformer = Transformer(config)
|
| 18 |
|