Update delta-iris/src/models/tokenizer/quantizer.py
Browse files
delta-iris/src/models/tokenizer/quantizer.py
CHANGED
|
@@ -20,7 +20,7 @@ class Quantizer(nn.Module):
|
|
| 20 |
self.register_buffer('codebook', codebook)
|
| 21 |
self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))
|
| 22 |
|
| 23 |
-
def forward(self, z: torch.Tensor) ->
|
| 24 |
z = self.pre_quant_proj(z)
|
| 25 |
z = F.normalize(z, dim=-1)
|
| 26 |
b, k = z.size(0), z.size(2)
|
|
|
|
| 20 |
self.register_buffer('codebook', codebook)
|
| 21 |
self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))
|
| 22 |
|
| 23 |
+
def forward(self, z: torch.Tensor) -> dict:
|
| 24 |
z = self.pre_quant_proj(z)
|
| 25 |
z = F.normalize(z, dim=-1)
|
| 26 |
b, k = z.size(0), z.size(2)
|