Update delta-iris/src/models/quantizer.py
Browse files
delta-iris/src/models/quantizer.py
CHANGED
|
@@ -15,9 +15,9 @@ class Quantizer(nn.Module):
|
|
| 15 |
self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
|
| 16 |
self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
|
| 17 |
codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
|
| 18 |
-
self.num_codebook_updates
|
| 19 |
-
self.codebook
|
| 20 |
-
self.codewords_freqs
|
| 21 |
|
| 22 |
def forward(self, z: torch.Tensor) -> dict:
|
| 23 |
z = self.pre_quant_proj(z)
|
|
|
|
| 15 |
self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
|
| 16 |
self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
|
| 17 |
codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
|
| 18 |
+
self.register_buffer('num_codebook_updates', torch.tensor(0))
|
| 19 |
+
self.register_buffer('codebook', codebook)
|
| 20 |
+
self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))
|
| 21 |
|
| 22 |
def forward(self, z: torch.Tensor) -> dict:
|
| 23 |
z = self.pre_quant_proj(z)
|