MCplayer committed on
Commit
109665c
·
1 Parent(s): cd99e3c
feature_extraction_xy_tokenizer.py CHANGED
@@ -82,8 +82,8 @@ class ExtractorIterator:
82
  x_padded = F.pad(x, (0, padding_size), "constant", 0)
83
  output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
84
  output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
85
- if padding_size > 0:
86
- output_lengths[-1] = kernel - padding_size
87
  output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
88
  return output_tensor, output_lengths, output_seq_no
89
 
@@ -118,6 +118,7 @@ class ExtractorIterator:
118
  ]
119
  yield BatchFeature({
120
  **self.encode_func(list_x),
 
121
  "chunk_seq_no": input_seq_no.clone(),
122
  })
123
 
@@ -135,6 +136,7 @@ class ExtractorIterator:
135
  ]
136
  yield BatchFeature({
137
  **self.encode_func(list_x),
 
138
  "chunk_seq_no": input_seq_no[:batch_num].clone(),
139
  })
140
 
@@ -147,6 +149,9 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
147
  hop_length=160,
148
  chunk_length=30,
149
  n_fft=400,
 
 
 
150
  padding_value=0.0,
151
  dither=0.0,
152
  return_attention_mask=False,
@@ -163,6 +168,9 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
163
  padding_value=padding_value,
164
  dither=dither,
165
  return_attention_mask=return_attention_mask,
 
 
 
166
  **kwargs,
167
  )
168
  self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
 
82
  x_padded = F.pad(x, (0, padding_size), "constant", 0)
83
  output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
84
  output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
85
+ for i in range(num_chunks):
86
+ output_lengths[i] = min(output_lengths[i], L - stride * i)
87
  output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
88
  return output_tensor, output_lengths, output_seq_no
89
 
 
118
  ]
119
  yield BatchFeature({
120
  **self.encode_func(list_x),
121
+ "input_lengths": input_lengths,
122
  "chunk_seq_no": input_seq_no.clone(),
123
  })
124
 
 
136
  ]
137
  yield BatchFeature({
138
  **self.encode_func(list_x),
139
+ "input_lengths": input_lengths,
140
  "chunk_seq_no": input_seq_no[:batch_num].clone(),
141
  })
142
 
 
149
  hop_length=160,
150
  chunk_length=30,
151
  n_fft=400,
152
+ n_samples=480000,
153
+ nb_max_frames=3000,
154
+ padding_side="right",
155
  padding_value=0.0,
156
  dither=0.0,
157
  return_attention_mask=False,
 
168
  padding_value=padding_value,
169
  dither=dither,
170
  return_attention_mask=return_attention_mask,
171
+ n_samples=n_samples,
172
+ nb_max_frames=nb_max_frames,
173
+ padding_side=padding_side,
174
  **kwargs,
175
  )
176
  self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
modeling_xy_tokenizer.py CHANGED
@@ -894,10 +894,9 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
894
 
895
  # 1. Iterate through chunks and store intermediate results
896
  for chunk_features in features:
897
- code_duration_length = features.duration_size // self.encoder_downsample_rate
898
  # Always use return_dict=True for easier access to named outputs
899
  chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
900
- valid_code_lengths = torch.clamp(chunk_output.codes_lengths, 0, code_duration_length)
901
 
902
  # Accumulate weighted commit loss
903
  chunk_length = chunk_output.codes_lengths.sum().item()
 
894
 
895
  # 1. Iterate through chunks and store intermediate results
896
  for chunk_features in features:
 
897
  # Always use return_dict=True for easier access to named outputs
898
  chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
899
+ valid_code_lengths = torch.clamp(chunk_features["input_lengths"], 0, features.duration_size) // self.encoder_downsample_rate
900
 
901
  # Accumulate weighted commit loss
902
  chunk_length = chunk_output.codes_lengths.sum().item()