Update delta-iris/src/tokenizer.py
Browse files- delta-iris/src/tokenizer.py +16 -0
delta-iris/src/tokenizer.py
CHANGED
|
@@ -64,6 +64,22 @@ class Tokenizer(nn.Module):
|
|
| 64 |
|
| 65 |
return r
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
@torch.no_grad()
|
| 68 |
def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
|
| 69 |
assert obs.size(1) == act.size(1) + 1
|
|
|
|
| 64 |
|
| 65 |
return r
|
| 66 |
|
| 67 |
+
def embed_tokens(self, tokens):
|
| 68 |
+
q = self.quantizer.embed_tokens(tokens)
|
| 69 |
+
b, t, hw, kle = q.shape
|
| 70 |
+
|
| 71 |
+
h = self.tokens_grid_res
|
| 72 |
+
w = self.tokens_grid_res
|
| 73 |
+
k = self.token_res
|
| 74 |
+
l = self.token_res
|
| 75 |
+
e = kle // (k * l)
|
| 76 |
+
|
| 77 |
+
q = q.reshape(b, t, h, w, k, l, e)
|
| 78 |
+
q = q.transpose(0, 1, 6, 2, 4, 3, 5)
|
| 79 |
+
q = q.reshape(b, t, e, h * k, w * l)
|
| 80 |
+
|
| 81 |
+
return q
|
| 82 |
+
|
| 83 |
@torch.no_grad()
|
| 84 |
def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
|
| 85 |
assert obs.size(1) == act.size(1) + 1
|