ShaswatRobotics commited on
Commit
ae0f181
·
verified ·
1 Parent(s): 8635508

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +6 -0
iris/src/tokenizer.py CHANGED
@@ -68,6 +68,12 @@ class Tokenizer(nn.Module):
68
  rec = self.postprocess_output(rec)
69
  return rec
70
 
 
 
 
 
 
 
71
  @torch.no_grad()
72
  def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
73
  z_q = self.encode(x, should_preprocess).z_quantized
 
68
  rec = self.postprocess_output(rec)
69
  return rec
70
 
71
+ def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
72
+ embedded_tokens = self.embedding(self.obs_tokens) # (B, K, E)
73
+ z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
74
+ rec = self.decode(z, should_postprocess=True) # (B, C, H, W)
75
+ return torch.clamp(rec, 0, 1)
76
+
77
  @torch.no_grad()
78
  def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
79
  z_q = self.encode(x, should_preprocess).z_quantized