ShaswatRobotics commited on
Commit
a416a2a
·
verified ·
1 Parent(s): cb5dd2f

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +1 -1
iris/src/tokenizer.py CHANGED
@@ -69,7 +69,7 @@ class Tokenizer(nn.Module):
69
  return rec
70
 
71
  def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
72
- embedded_tokens = self.embedding(self.obs_tokens) # (B, K, E)
73
  z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
74
  rec = self.decode(z, should_postprocess=True) # (B, C, H, W)
75
  return torch.clamp(rec, 0, 1)
 
69
  return rec
70
 
71
  def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
72
+ embedded_tokens = self.embedding(obs_tokens) # (B, K, E)
73
  z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
74
  rec = self.decode(z, should_postprocess=True) # (B, C, H, W)
75
  return torch.clamp(rec, 0, 1)