Update iris/src/world_model.py
Browse files- iris/src/world_model.py +3 -3
iris/src/world_model.py
CHANGED
|
@@ -5,9 +5,9 @@ import torch
|
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
-
from models.kv_caching import KeysValues
|
| 9 |
-
from models.slicer import Embedder, Head
|
| 10 |
-
from models.transformer import Transformer
|
| 11 |
|
| 12 |
class WorldModel(nn.Module):
|
| 13 |
def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: dict) -> None:
|
|
|
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
+
from .models.kv_caching import KeysValues
|
| 9 |
+
from .models.slicer import Embedder, Head
|
| 10 |
+
from .models.transformer import Transformer
|
| 11 |
|
| 12 |
class WorldModel(nn.Module):
|
| 13 |
def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: dict) -> None:
|