ShaswatRobotics commited on
Commit
7d4a014
·
verified ·
1 Parent(s): ccf803e

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +2 -2
iris/src/tokenizer.py CHANGED
@@ -29,9 +29,9 @@ class Tokenizer(nn.Module):
29
 
30
  def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]:
31
  outputs = self.encode(x, should_preprocess)
32
- decoder_input = outputs.z + (outputs.z_quantized - outputs.z).detach()
33
  reconstructions = self.decode(decoder_input, should_postprocess)
34
- return outputs.z, outputs.z_quantized, reconstructions
35
 
36
  def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
37
  if should_preprocess:
 
29
 
30
  def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]:
31
  outputs = self.encode(x, should_preprocess)
32
+ decoder_input = outputs["z"] + (outputs["z_quantized"] - outputs["z"]).detach()
33
  reconstructions = self.decode(decoder_input, should_postprocess)
34
+ return outputs["z"], outputs["z_quantized"], reconstructions
35
 
36
  def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
37
  if should_preprocess: