ShaswatRobotics commited on
Commit
5bd6809
·
verified ·
1 Parent(s): 6db0f6c

Update delta-iris/src/models/quantizer.py

Browse files
Files changed (1) hide show
  1. delta-iris/src/models/quantizer.py +3 -3
delta-iris/src/models/quantizer.py CHANGED
@@ -15,9 +15,9 @@ class Quantizer(nn.Module):
15
  self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
16
  self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
17
  codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
18
- self.num_codebook_updates = torch.tensor(0)
19
- self.codebook = codebook
20
- self.codewords_freqs = torch.ones(codebook_size).div(codebook_size)
21
 
22
  def forward(self, z: torch.Tensor) -> dict:
23
  z = self.pre_quant_proj(z)
 
15
  self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
16
  self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
17
  codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
18
+ self.register_buffer('num_codebook_updates', torch.tensor(0))
19
+ self.register_buffer('codebook', codebook)
20
+ self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))
21
 
22
  def forward(self, z: torch.Tensor) -> dict:
23
  z = self.pre_quant_proj(z)