ShaswatRobotics commited on
Commit
a9b48d6
·
verified ·
1 Parent(s): 6fdbffe

Update delta-iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. delta-iris/src/tokenizer.py +16 -0
delta-iris/src/tokenizer.py CHANGED
@@ -64,6 +64,22 @@ class Tokenizer(nn.Module):
64
 
65
  return r
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  @torch.no_grad()
68
  def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
69
  assert obs.size(1) == act.size(1) + 1
 
64
 
65
  return r
66
 
67
+ def embed_tokens(self, tokens):
68
+ q = self.quantizer.embed_tokens(tokens)
69
+ b, t, hw, kle = q.shape
70
+
71
+ h = self.tokens_grid_res
72
+ w = self.tokens_grid_res
73
+ k = self.token_res
74
+ l = self.token_res
75
+ e = kle // (k * l)
76
+
77
+ q = q.reshape(b, t, h, w, k, l, e)
78
+ q = q.transpose(0, 1, 6, 2, 4, 3, 5)
79
+ q = q.reshape(b, t, e, h * k, w * l)
80
+
81
+ return q
82
+
83
  @torch.no_grad()
84
  def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
85
  assert obs.size(1) == act.size(1) + 1