Upload 8 files
Browse files
delta-iris/src/models/convnet.py
CHANGED
|
@@ -1,11 +1,7 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
-
from einops import rearrange
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
|
| 10 |
class FrameEncoder(nn.Module):
|
| 11 |
def __init__(self, config: dict) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
|
| 6 |
class FrameEncoder(nn.Module):
|
| 7 |
def __init__(self, config: dict) -> None:
|
delta-iris/src/models/quantizer.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class Quantizer(nn.Module):
    """Vector quantizer with cosine-similarity codebook lookup and a
    straight-through gradient estimator.

    Inputs are projected to the codebook dimension, L2-normalized, snapped to
    the nearest codeword by cosine similarity, and projected back to the input
    dimension. Codeword-usage frequencies are tracked for an entropy-based
    revival scheme (the revival logic itself lives outside this block).
    """

    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int,
                 max_codebook_updates_with_revival: Optional[int] = None) -> None:
        super().__init__()
        # codebook_size must be a power of two so the log2-based entropy
        # threshold below is an integer number of bits.
        assert math.log2(codebook_size).is_integer()
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival

        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)

        # FIX: register non-trainable state as buffers so it follows
        # .to(device)/.cuda() and is persisted in state_dict. The original
        # stored these as plain tensor attributes, which silently stay on CPU
        # and are lost at checkpointing. (requires_grad=False on torch.empty
        # was redundant: it is the default and buffers never require grad.)
        codebook = torch.empty(codebook_size, codebook_dim).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.register_buffer("codebook", codebook)
        self.register_buffer("num_codebook_updates", torch.tensor(0))
        # Start from a uniform usage distribution over the codewords.
        self.register_buffer("codewords_freqs", torch.full((codebook_size,), 1.0 / codebook_size))

    def forward(self, z: torch.Tensor) -> dict:
        """Quantize ``z`` of shape (b, t, k, input_dim).

        Returns a dict with:
          - "q": (b, t, k, input_dim) straight-through quantized output;
          - "tokens": (b, t, k) LongTensor of selected codeword indices.
        """
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, t, k, e = z.shape
        # Flatten to (n, e) for a single batched codebook lookup.
        z = z.reshape(b * t * k, e)

        # Cosine similarity against every codeword (z rows are unit-norm;
        # NOTE(review): codewords are not re-normalized here — presumably kept
        # near unit norm by the external update scheme; confirm upstream).
        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]

        # Straight-through estimator: forward pass uses q, backward pass
        # routes gradients to z as if the quantization were the identity.
        q = z + (q - z).detach()
        q = self.post_quant_proj(q)

        return {
            "q": q.reshape(b, t, k, -1),
            "tokens": tokens.reshape(b, t, k),
        }

    def compute_codebook_entropy(self) -> float:
        """Shannon entropy (in bits) of the tracked codeword-usage
        distribution; zero-frequency entries are excluded so log2 stays
        finite."""
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Map token indices to decoded embeddings: codebook lookup followed
        by the output projection, with gradients disabled."""
        return self.post_quant_proj(self.codebook[tokens])
|
delta-iris/src/models/transformer.py
CHANGED
|
@@ -2,10 +2,9 @@
|
|
| 2 |
Inspired from https://github.com/karpathy/minGPT
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
from dataclasses import dataclass
|
| 6 |
from typing import Optional
|
| 7 |
-
|
| 8 |
from einops import rearrange
|
|
|
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
|
|
|
|
| 2 |
Inspired from https://github.com/karpathy/minGPT
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
from typing import Optional
|
|
|
|
| 6 |
from einops import rearrange
|
| 7 |
+
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
|
delta-iris/src/tokenizer.py
CHANGED
|
@@ -1,14 +1,10 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
import math
|
| 3 |
-
from typing import Dict, Tuple
|
| 4 |
-
|
| 5 |
from einops import rearrange
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
|
| 9 |
from .models.convnet import FrameEncoder, FrameDecoder
|
| 10 |
-
from .models.quantizer import Quantizer
|
| 11 |
-
from .models.utils import init_weights
|
| 12 |
|
| 13 |
class Tokenizer(nn.Module):
|
| 14 |
def __init__(self, config: dict) -> None:
|
|
@@ -32,8 +28,6 @@ class Tokenizer(nn.Module):
|
|
| 32 |
self.decoder = FrameDecoder(config["decoder_config"])
|
| 33 |
self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
|
| 34 |
|
| 35 |
-
self.apply(init_weights)
|
| 36 |
-
|
| 37 |
def __repr__(self) -> str:
|
| 38 |
return "tokenizer"
|
| 39 |
|
|
|
|
|
|
|
| 1 |
import math
|
|
|
|
|
|
|
| 2 |
from einops import rearrange
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
|
| 6 |
from .models.convnet import FrameEncoder, FrameDecoder
|
| 7 |
+
from .models.quantizer import Quantizer
|
|
|
|
| 8 |
|
| 9 |
class Tokenizer(nn.Module):
|
| 10 |
def __init__(self, config: dict) -> None:
|
|
|
|
| 28 |
self.decoder = FrameDecoder(config["decoder_config"])
|
| 29 |
self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
|
| 30 |
|
|
|
|
|
|
|
| 31 |
def __repr__(self) -> str:
|
| 32 |
return "tokenizer"
|
| 33 |
|
delta-iris/src/world_model.py
CHANGED
|
@@ -1,15 +1,11 @@
|
|
| 1 |
-
from
|
| 2 |
-
|
| 3 |
-
from einops import rearrange, repeat
|
| 4 |
from einops.layers.torch import Rearrange
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
|
| 9 |
from .models.convnet import FrameEncoder
|
| 10 |
from .models.slicer import Head
|
| 11 |
from .models.transformer import TransformerEncoder
|
| 12 |
-
from .models.utils import init_weights
|
| 13 |
|
| 14 |
class WorldModel(nn.Module):
|
| 15 |
def __init__(self, config: dict) -> None:
|
|
@@ -55,8 +51,6 @@ class WorldModel(nn.Module):
|
|
| 55 |
)
|
| 56 |
)
|
| 57 |
|
| 58 |
-
self.apply(init_weights)
|
| 59 |
-
|
| 60 |
def __repr__(self) -> str:
|
| 61 |
return "world_model"
|
| 62 |
|
|
|
|
| 1 |
+
from einops import rearrange
|
|
|
|
|
|
|
| 2 |
from einops.layers.torch import Rearrange
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
| 5 |
|
| 6 |
from .models.convnet import FrameEncoder
|
| 7 |
from .models.slicer import Head
|
| 8 |
from .models.transformer import TransformerEncoder
|
|
|
|
| 9 |
|
| 10 |
class WorldModel(nn.Module):
|
| 11 |
def __init__(self, config: dict) -> None:
|
|
|
|
| 51 |
)
|
| 52 |
)
|
| 53 |
|
|
|
|
|
|
|
| 54 |
def __repr__(self) -> str:
|
| 55 |
return "world_model"
|
| 56 |
|