Update iris/src/tokenizer.py
Browse files- iris/src/tokenizer.py +6 -0
iris/src/tokenizer.py
CHANGED
|
@@ -68,6 +68,12 @@ class Tokenizer(nn.Module):
|
|
| 68 |
rec = self.postprocess_output(rec)
|
| 69 |
return rec
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
@torch.no_grad()
|
| 72 |
def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
|
| 73 |
z_q = self.encode(x, should_preprocess).z_quantized
|
|
|
|
| 68 |
rec = self.postprocess_output(rec)
|
| 69 |
return rec
|
| 70 |
|
| 71 |
+
def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
    """Decode a batch of observation tokens back into image observations.

    Args:
        obs_tokens: integer token indices of shape (B, K) where
            K == num_observations_tokens — assumed; TODO confirm against caller.
        num_observations_tokens: K, the number of tokens per observation;
            must be a perfect square so tokens form an (h, w) grid.

    Returns:
        Reconstructed observations of shape (B, C, H, W), clamped to [0, 1].
    """
    # Tokens must tile a square spatial grid for the (h w) -> h, w unflatten.
    h = int(np.sqrt(num_observations_tokens))
    assert h * h == num_observations_tokens, "num_observations_tokens must be a perfect square"
    # BUG FIX: use the obs_tokens argument, not self.obs_tokens — the original
    # ignored the parameter and read (possibly stale/absent) instance state.
    embedded_tokens = self.embedding(obs_tokens)  # (B, K, E)
    z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=h)
    rec = self.decode(z, should_postprocess=True)  # (B, C, H, W)
    return torch.clamp(rec, 0, 1)
|
| 76 |
+
|
| 77 |
@torch.no_grad()
|
| 78 |
def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
|
| 79 |
z_q = self.encode(x, should_preprocess).z_quantized
|