Update iris/src/tokenizer.py
Browse files- iris/src/tokenizer.py +2 -2
iris/src/tokenizer.py
CHANGED
|
@@ -29,9 +29,9 @@ class Tokenizer(nn.Module):
|
|
| 29 |
|
| 30 |
def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]:
|
| 31 |
outputs = self.encode(x, should_preprocess)
|
| 32 |
-
decoder_input = outputs
|
| 33 |
reconstructions = self.decode(decoder_input, should_postprocess)
|
| 34 |
-
return outputs
|
| 35 |
|
| 36 |
def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
|
| 37 |
if should_preprocess:
|
|
|
|
| 29 |
|
| 30 |
def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]:
|
| 31 |
outputs = self.encode(x, should_preprocess)
|
| 32 |
+
decoder_input = outputs["z"] + (outputs["z_quantized"] - outputs["z"]).detach()
|
| 33 |
reconstructions = self.decode(decoder_input, should_postprocess)
|
| 34 |
+
return outputs["z"], outputs["z_quantized"], reconstructions
|
| 35 |
|
| 36 |
def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
|
| 37 |
if should_preprocess:
|