Update iris/src/tokenizer.py
Browse files- iris/src/tokenizer.py +1 -1
iris/src/tokenizer.py
CHANGED
|
@@ -69,7 +69,7 @@ class Tokenizer(nn.Module):
|
|
| 69 |
return rec
|
| 70 |
|
| 71 |
def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
|
| 72 |
-
embedded_tokens = self.embedding(
|
| 73 |
z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
|
| 74 |
rec = self.decode(z, should_postprocess=True) # (B, C, H, W)
|
| 75 |
return torch.clamp(rec, 0, 1)
|
|
|
|
| 69 |
return rec
|
| 70 |
|
| 71 |
def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
|
| 72 |
+
embedded_tokens = self.embedding(obs_tokens) # (B, K, E)
|
| 73 |
z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
|
| 74 |
rec = self.decode(z, should_postprocess=True) # (B, C, H, W)
|
| 75 |
return torch.clamp(rec, 0, 1)
|