klemenk commited on
Commit
a546ccb
·
verified ·
1 Parent(s): 1412e6f

Update modeling_wavcoch.py

Browse files
Files changed (1) hide show
  1. modeling_wavcoch.py +2 -2
modeling_wavcoch.py CHANGED
@@ -508,7 +508,7 @@ class WavCoch(PreTrainedModel):
508
  if pad:
509
  wav = F.pad(wav, (self.N - self.hop_length, 0), mode="constant", value=0.0)
510
 
511
- codes = self.quantize(wav, pad=False).long() # already padded
512
  return BatchEncoding({"input_values": codes, "input_ids": codes})
513
 
514
  # Training / reconstruction mode
@@ -542,7 +542,7 @@ class WavCoch(PreTrainedModel):
542
  x = real_part + imag_part
543
  x = self.encoder(x).permute(0, 2, 1) # (B, T, D)
544
  _, indices = self.quantizer(x) # (B, T)
545
- return indices
546
 
547
  @torch.no_grad()
548
  def decode(self, indices: torch.Tensor) -> torch.Tensor:
 
508
  if pad:
509
  wav = F.pad(wav, (self.N - self.hop_length, 0), mode="constant", value=0.0)
510
 
511
+ codes = self.quantize(wav, pad=False) # already padded
512
  return BatchEncoding({"input_values": codes, "input_ids": codes})
513
 
514
  # Training / reconstruction mode
 
542
  x = real_part + imag_part
543
  x = self.encoder(x).permute(0, 2, 1) # (B, T, D)
544
  _, indices = self.quantizer(x) # (B, T)
545
+ return indices.long()
546
 
547
  @torch.no_grad()
548
  def decode(self, indices: torch.Tensor) -> torch.Tensor: