Update modeling_wavcoch.py
Browse files
- modeling_wavcoch.py +2 -2
modeling_wavcoch.py
CHANGED
|
@@ -508,7 +508,7 @@ class WavCoch(PreTrainedModel):
|
|
| 508 | if pad: |
| 509 | wav = F.pad(wav, (self.N - self.hop_length, 0), mode="constant", value=0.0) |
| 510 | |
| 511 | - codes = self.quantize(wav, pad=False) |
| 512 | return BatchEncoding({"input_values": codes, "input_ids": codes}) |
| 513 | |
| 514 | # Training / reconstruction mode |
|
@@ -542,7 +542,7 @@ class WavCoch(PreTrainedModel):
|
|
| 542 | x = real_part + imag_part |
| 543 | x = self.encoder(x).permute(0, 2, 1) # (B, T, D) |
| 544 | _, indices = self.quantizer(x) # (B, T) |
| 545 | - return indices |
| 546 | |
| 547 | @torch.no_grad() |
| 548 | def decode(self, indices: torch.Tensor) -> torch.Tensor: |
|
|
|
| 508 | if pad: |
| 509 | wav = F.pad(wav, (self.N - self.hop_length, 0), mode="constant", value=0.0) |
| 510 | |
| 511 | + codes = self.quantize(wav, pad=False) # already padded |
| 512 | return BatchEncoding({"input_values": codes, "input_ids": codes}) |
| 513 | |
| 514 | # Training / reconstruction mode |
|
|
|
| 542 | x = real_part + imag_part |
| 543 | x = self.encoder(x).permute(0, 2, 1) # (B, T, D) |
| 544 | _, indices = self.quantizer(x) # (B, T) |
| 545 | + return indices.long() |
| 546 | |
| 547 | @torch.no_grad() |
| 548 | def decode(self, indices: torch.Tensor) -> torch.Tensor: |