ShaswatRobotics committed on
Commit
f9f6093
·
verified ·
1 Parent(s): 7c81dfd

Upload 8 files

Browse files
delta-iris/src/models/convnet.py CHANGED
@@ -1,11 +1,7 @@
1
- from dataclasses import dataclass
2
- from typing import List
3
-
4
- from einops import rearrange
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
-
9
 
10
  class FrameEncoder(nn.Module):
11
  def __init__(self, config: dict) -> None:
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from einops import rearrange
5
 
6
  class FrameEncoder(nn.Module):
7
  def __init__(self, config: dict) -> None:
delta-iris/src/models/quantizer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+ from einops import rearrange
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
class Quantizer(nn.Module):
    """Vector quantizer with cosine-similarity code assignment and a
    straight-through gradient estimator.

    Inputs are projected to the codebook dimension, L2-normalized, and
    assigned to the codeword with the highest dot product.
    NOTE(review): the codebook initialized here is *not* normalized, so the
    "cosine similarity" is exact only if the codebook is kept normalized by
    the (external) update logic — confirm against the trainer.
    """

    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
        """
        Args:
            codebook_size: number of codewords; must be a power of two.
            codebook_dim: dimensionality of each codeword.
            input_dim: dimensionality of incoming features.
            max_codebook_updates_with_revival: cap on updates during which
                dead-code revival may run (None = unlimited). Stored only;
                the revival logic itself is not in this class.
        """
        super().__init__()
        assert math.log2(codebook_size).is_integer()
        # Revival threshold: usage entropy below log2(codebook_size) - 2 bits
        # means codeword usage is far from uniform.
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        # FIX: register these tensors as buffers instead of plain attributes
        # so they follow the module across .to()/.cuda() device moves.
        # persistent=False keeps the state_dict identical to the original
        # implementation, so existing checkpoints still load unchanged.
        self.register_buffer("num_codebook_updates", torch.tensor(0), persistent=False)
        self.register_buffer("codebook", codebook, persistent=False)
        self.register_buffer("codewords_freqs", torch.ones(codebook_size).div(codebook_size), persistent=False)

    def forward(self, z: torch.Tensor) -> dict:
        """Quantize features.

        Args:
            z: tensor of shape (b, t, k, input_dim).

        Returns:
            dict with:
                "q": de-quantized features, shape (b, t, k, input_dim),
                     carrying straight-through gradients to ``z``.
                "tokens": codeword indices, shape (b, t, k), dtype long.
        """
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, t, k, _ = z.shape
        # Flatten to (b*t*k, codebook_dim) for a single nearest-code lookup.
        z = z.reshape(b * t * k, -1)

        # Dot product against every codeword; argmax picks the nearest code.
        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]

        # Straight-through estimator: forward pass uses q, backward pass
        # routes gradients through z unchanged.
        q = z + (q - z).detach()
        q = self.post_quant_proj(q)

        q = q.view(b, t, k, -1)
        tokens = tokens.view(b, t, k)
        return {
            "q": q,
            "tokens": tokens,
        }

    def compute_codebook_entropy(self) -> float:
        """Shannon entropy (bits) of the running codeword usage frequencies;
        zero-probability codes are excluded so log2 never sees 0."""
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Map token indices back to input-space embeddings (no gradients)."""
        return self.post_quant_proj(self.codebook[tokens])
delta-iris/src/models/transformer.py CHANGED
@@ -2,10 +2,9 @@
2
  Inspired from https://github.com/karpathy/minGPT
3
  """
4
 
5
- from dataclasses import dataclass
6
  from typing import Optional
7
-
8
  from einops import rearrange
 
9
  import torch
10
  import torch.nn as nn
11
 
 
2
  Inspired from https://github.com/karpathy/minGPT
3
  """
4
 
 
5
  from typing import Optional
 
6
  from einops import rearrange
7
+
8
  import torch
9
  import torch.nn as nn
10
 
delta-iris/src/tokenizer.py CHANGED
@@ -1,14 +1,10 @@
1
- from dataclasses import dataclass
2
  import math
3
- from typing import Dict, Tuple
4
-
5
  from einops import rearrange
6
  import torch
7
  import torch.nn as nn
8
 
9
  from .models.convnet import FrameEncoder, FrameDecoder
10
- from .models.tokenizer.quantizer import Quantizer
11
- from .models.utils import init_weights
12
 
13
  class Tokenizer(nn.Module):
14
  def __init__(self, config: dict) -> None:
@@ -32,8 +28,6 @@ class Tokenizer(nn.Module):
32
  self.decoder = FrameDecoder(config["decoder_config"])
33
  self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
34
 
35
- self.apply(init_weights)
36
-
37
  def __repr__(self) -> str:
38
  return "tokenizer"
39
 
 
 
1
  import math
 
 
2
  from einops import rearrange
3
  import torch
4
  import torch.nn as nn
5
 
6
  from .models.convnet import FrameEncoder, FrameDecoder
7
+ from .models.quantizer import Quantizer
 
8
 
9
  class Tokenizer(nn.Module):
10
  def __init__(self, config: dict) -> None:
 
28
  self.decoder = FrameDecoder(config["decoder_config"])
29
  self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
30
 
 
 
31
  def __repr__(self) -> str:
32
  return "tokenizer"
33
 
delta-iris/src/world_model.py CHANGED
@@ -1,15 +1,11 @@
1
- from dataclasses import dataclass
2
-
3
- from einops import rearrange, repeat
4
  from einops.layers.torch import Rearrange
5
  import torch
6
  import torch.nn as nn
7
- import torch.nn.functional as F
8
 
9
  from .models.convnet import FrameEncoder
10
  from .models.slicer import Head
11
  from .models.transformer import TransformerEncoder
12
- from .models.utils import init_weights
13
 
14
  class WorldModel(nn.Module):
15
  def __init__(self, config: dict) -> None:
@@ -55,8 +51,6 @@ class WorldModel(nn.Module):
55
  )
56
  )
57
 
58
- self.apply(init_weights)
59
-
60
  def __repr__(self) -> str:
61
  return "world_model"
62
 
 
1
+ from einops import rearrange
 
 
2
  from einops.layers.torch import Rearrange
3
  import torch
4
  import torch.nn as nn
 
5
 
6
  from .models.convnet import FrameEncoder
7
  from .models.slicer import Head
8
  from .models.transformer import TransformerEncoder
 
9
 
10
  class WorldModel(nn.Module):
11
  def __init__(self, config: dict) -> None:
 
51
  )
52
  )
53
 
 
 
54
  def __repr__(self) -> str:
55
  return "world_model"
56