fix bugs
Browse files
feature_extraction_xy_tokenizer.py
CHANGED
|
@@ -82,8 +82,8 @@ class ExtractorIterator:
|
|
| 82 |
x_padded = F.pad(x, (0, padding_size), "constant", 0)
|
| 83 |
output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
|
| 84 |
output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
|
| 85 |
-
|
| 86 |
-
output_lengths[
|
| 87 |
output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
|
| 88 |
return output_tensor, output_lengths, output_seq_no
|
| 89 |
|
|
@@ -118,6 +118,7 @@ class ExtractorIterator:
|
|
| 118 |
]
|
| 119 |
yield BatchFeature({
|
| 120 |
**self.encode_func(list_x),
|
|
|
|
| 121 |
"chunk_seq_no": input_seq_no.clone(),
|
| 122 |
})
|
| 123 |
|
|
@@ -135,6 +136,7 @@ class ExtractorIterator:
|
|
| 135 |
]
|
| 136 |
yield BatchFeature({
|
| 137 |
**self.encode_func(list_x),
|
|
|
|
| 138 |
"chunk_seq_no": input_seq_no[:batch_num].clone(),
|
| 139 |
})
|
| 140 |
|
|
@@ -147,6 +149,9 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
|
| 147 |
hop_length=160,
|
| 148 |
chunk_length=30,
|
| 149 |
n_fft=400,
|
|
|
|
|
|
|
|
|
|
| 150 |
padding_value=0.0,
|
| 151 |
dither=0.0,
|
| 152 |
return_attention_mask=False,
|
|
@@ -163,6 +168,9 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
|
| 163 |
padding_value=padding_value,
|
| 164 |
dither=dither,
|
| 165 |
return_attention_mask=return_attention_mask,
|
|
|
|
|
|
|
|
|
|
| 166 |
**kwargs,
|
| 167 |
)
|
| 168 |
self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
|
|
|
|
| 82 |
x_padded = F.pad(x, (0, padding_size), "constant", 0)
|
| 83 |
output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
|
| 84 |
output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
|
| 85 |
+
for i in range(num_chunks):
|
| 86 |
+
output_lengths[i] = min(output_lengths[i], L - stride * i)
|
| 87 |
output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
|
| 88 |
return output_tensor, output_lengths, output_seq_no
|
| 89 |
|
|
|
|
| 118 |
]
|
| 119 |
yield BatchFeature({
|
| 120 |
**self.encode_func(list_x),
|
| 121 |
+
"input_lengths": input_lengths,
|
| 122 |
"chunk_seq_no": input_seq_no.clone(),
|
| 123 |
})
|
| 124 |
|
|
|
|
| 136 |
]
|
| 137 |
yield BatchFeature({
|
| 138 |
**self.encode_func(list_x),
|
| 139 |
+
"input_lengths": input_lengths,
|
| 140 |
"chunk_seq_no": input_seq_no[:batch_num].clone(),
|
| 141 |
})
|
| 142 |
|
|
|
|
| 149 |
hop_length=160,
|
| 150 |
chunk_length=30,
|
| 151 |
n_fft=400,
|
| 152 |
+
n_samples=480000,
|
| 153 |
+
nb_max_frames=3000,
|
| 154 |
+
padding_side="right",
|
| 155 |
padding_value=0.0,
|
| 156 |
dither=0.0,
|
| 157 |
return_attention_mask=False,
|
|
|
|
| 168 |
padding_value=padding_value,
|
| 169 |
dither=dither,
|
| 170 |
return_attention_mask=return_attention_mask,
|
| 171 |
+
n_samples=n_samples,
|
| 172 |
+
nb_max_frames=nb_max_frames,
|
| 173 |
+
padding_side=padding_side,
|
| 174 |
**kwargs,
|
| 175 |
)
|
| 176 |
self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
|
modeling_xy_tokenizer.py
CHANGED
|
@@ -894,10 +894,9 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
|
|
| 894 |
|
| 895 |
# 1. Iterate through chunks and store intermediate results
|
| 896 |
for chunk_features in features:
|
| 897 |
-
code_duration_length = features.duration_size // self.encoder_downsample_rate
|
| 898 |
# Always use return_dict=True for easier access to named outputs
|
| 899 |
chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
|
| 900 |
-
valid_code_lengths = torch.clamp(
|
| 901 |
|
| 902 |
# Accumulate weighted commit loss
|
| 903 |
chunk_length = chunk_output.codes_lengths.sum().item()
|
|
|
|
| 894 |
|
| 895 |
# 1. Iterate through chunks and store intermediate results
|
| 896 |
for chunk_features in features:
|
|
|
|
| 897 |
# Always use return_dict=True for easier access to named outputs
|
| 898 |
chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
|
| 899 |
+
valid_code_lengths = torch.clamp(chunk_features["input_lengths"], 0, features.duration_size) // self.encoder_downsample_rate
|
| 900 |
|
| 901 |
# Accumulate weighted commit loss
|
| 902 |
chunk_length = chunk_output.codes_lengths.sum().item()
|