MCplayer committed on
Commit
f038693
·
1 Parent(s): 109665c

add middle overlap type

Browse files
feature_extraction_xy_tokenizer.py CHANGED
@@ -18,6 +18,7 @@ Feature extractor class for Whisper
18
  import math
19
  from functools import partial
20
  from typing import List, Optional, Union
 
21
 
22
  import torch
23
  import torch.nn.functional as F
@@ -34,9 +35,10 @@ class ExtractorIterator:
34
  def __init__(
35
  self,
36
  data,
37
- batch_size=1,
38
  chunk_length=30,
39
- overlap_seconds=10,
 
40
  sampling_rate=16000,
41
  encode_func = None,
42
  ) -> None:
@@ -44,12 +46,16 @@ class ExtractorIterator:
44
  self.batch_size = batch_size
45
  self.chunk_length = chunk_length
46
  self.overlap_seconds = overlap_seconds
 
47
  self.sampling_rate = sampling_rate
48
 
49
  # duration_size 是每次处理的有效音频长度
50
  self.chunk_size = int(self.chunk_length * self.sampling_rate)
51
- self.duration_seconds = self.chunk_length - self.overlap_seconds
52
- self.duration_size = int(self.duration_seconds * self.sampling_rate)
 
 
 
53
  # 注意:这里我们只处理不带重叠的块,重叠将在外部处理(如果需要)
54
  # 或者在迭代器内部更明确地处理。为了简化,我们假设分块是基于 duration_size
55
 
@@ -66,29 +72,13 @@ class ExtractorIterator:
66
 
67
  # 注意:chunk_and_pad_view 输出的块大小是 duration_size
68
  wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
69
- input_lengths = torch.zeros(self.batch_size, dtype=torch.long)
70
  input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
71
-
72
- def chunk_and_pad_view(tensor, seq_no):
73
- x = tensor[0:1, :].unsqueeze(0)
74
-
75
- stride = self.duration_size
76
- kernel = self.chunk_size
77
- B, C, L = x.shape
78
-
79
- num_chunks = math.ceil(L / stride)
80
- target_len = (num_chunks - 1) * stride + kernel
81
- padding_size = max(0, target_len - L)
82
- x_padded = F.pad(x, (0, padding_size), "constant", 0)
83
- output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
84
- output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
85
- for i in range(num_chunks):
86
- output_lengths[i] = min(output_lengths[i], L - stride * i)
87
- output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
88
- return output_tensor, output_lengths, output_seq_no
89
 
90
  for i, sample in enumerate(self.data):
91
- sample_chunks, sample_lengths, sample_seq_no = chunk_and_pad_view(sample, i)
92
 
93
  processed_in_sample = 0
94
  while processed_in_sample < len(sample_chunks):
@@ -103,7 +93,7 @@ class ExtractorIterator:
103
 
104
  # 填充数据
105
  wav_tensor[start_idx_batch:end_idx_batch] = sample_chunks[start_idx_sample:end_idx_sample]
106
- input_lengths[start_idx_batch:end_idx_batch] = sample_lengths[start_idx_sample:end_idx_sample]
107
  input_seq_no[start_idx_batch:end_idx_batch] = sample_seq_no[start_idx_sample:end_idx_sample]
108
 
109
  # 更新计数器
@@ -112,10 +102,13 @@ class ExtractorIterator:
112
 
113
  # 如果批次满了,yield 一个副本并重置
114
  if batch_num == self.batch_size:
115
- list_x = [
116
- wav_tensor[xi, :, :x_len].reshape(-1).cpu().numpy()
117
- for xi, x_len in enumerate(input_lengths.tolist())
118
- ]
 
 
 
119
  yield BatchFeature({
120
  **self.encode_func(list_x),
121
  "input_lengths": input_lengths,
@@ -125,21 +118,62 @@ class ExtractorIterator:
125
  # 重置批次计数器和Tensor内容
126
  batch_num = 0
127
  wav_tensor.zero_()
128
- input_lengths.zero_()
129
  input_seq_no.zero_()
130
 
131
  # 循环结束后,处理最后一个未满的批次
132
  if batch_num > 0:
133
- list_x = [
134
- wav_tensor[xi, :, :x_len].reshape(-1).cpu().numpy()
135
- for xi, x_len in enumerate(input_lengths.tolist())
136
- ]
 
 
 
137
  yield BatchFeature({
138
  **self.encode_func(list_x),
139
  "input_lengths": input_lengths,
140
  "chunk_seq_no": input_seq_no[:batch_num].clone(),
141
  })
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
145
  def __init__(
@@ -156,7 +190,8 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
156
  dither=0.0,
157
  return_attention_mask=False,
158
  max_frequency=None,
159
- batch_size=None,
 
160
  **kwargs,
161
  ):
162
  super().__init__(
@@ -184,6 +219,7 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
184
  norm="slaney",
185
  mel_scale="slaney",
186
  )
 
187
 
188
  def __call__(
189
  self,
@@ -207,9 +243,10 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
207
 
208
  return ExtractorIterator(
209
  raw_speech,
210
- batch_size=len(raw_speech) if self.batch_size is None else self.batch_size,
211
  chunk_length=self.chunk_length,
212
  overlap_seconds=overlap_seconds,
 
213
  sampling_rate=self.sampling_rate,
214
  encode_func=partial(
215
  super().__call__,
 
18
  import math
19
  from functools import partial
20
  from typing import List, Optional, Union
21
+ from collections import deque
22
 
23
  import torch
24
  import torch.nn.functional as F
 
35
  def __init__(
36
  self,
37
  data,
38
+ batch_size=8,
39
  chunk_length=30,
40
+ overlap_seconds=10,
41
+ overlap_side="both",
42
  sampling_rate=16000,
43
  encode_func = None,
44
  ) -> None:
 
46
  self.batch_size = batch_size
47
  self.chunk_length = chunk_length
48
  self.overlap_seconds = overlap_seconds
49
+ self.overlap_side = overlap_side
50
  self.sampling_rate = sampling_rate
51
 
52
  # duration_size 是每次处理的有效音频长度
53
  self.chunk_size = int(self.chunk_length * self.sampling_rate)
54
+ self.overlap_size = int(self.overlap_seconds * self.sampling_rate)
55
+ self.duration_size = self.chunk_size - self.overlap_size
56
+ assert (
57
+ (overlap_side == "right") or (self.overlap_size % 2 == 0)
58
+ ), '`overlap_seconds` must be divisible by 2 when `overlap_side` is "both".'
59
  # 注意:这里我们只处理不带重叠的块,重叠将在外部处理(如果需要)
60
  # 或者在迭代器内部更明确地处理。为了简化,我们假设分块是基于 duration_size
61
 
 
72
 
73
  # 注意:chunk_and_pad_view 输出的块大小是 duration_size
74
  wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
75
+ input_lengths = deque(maxlen=self.batch_size)
76
  input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
77
+
78
+ right_boundary = self.get_right_boundary()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  for i, sample in enumerate(self.data):
81
+ sample_chunks, sample_lengths, sample_seq_no = self.chunk_and_pad_view(sample, i)
82
 
83
  processed_in_sample = 0
84
  while processed_in_sample < len(sample_chunks):
 
93
 
94
  # 填充数据
95
  wav_tensor[start_idx_batch:end_idx_batch] = sample_chunks[start_idx_sample:end_idx_sample]
96
+ input_lengths.extend(sample_lengths[start_idx_sample:end_idx_sample])
97
  input_seq_no[start_idx_batch:end_idx_batch] = sample_seq_no[start_idx_sample:end_idx_sample]
98
 
99
  # 更新计数器
 
102
 
103
  # 如果批次满了,yield 一个副本并重置
104
  if batch_num == self.batch_size:
105
+ list_x = []
106
+ for xi, (_, right) in enumerate(input_lengths):
107
+ if right == right_boundary and torch.any(wav_tensor[xi, :, right:] != 0):
108
+ list_x.append(wav_tensor[xi].reshape(-1).cpu().numpy())
109
+ else:
110
+ list_x.append(wav_tensor[xi, :, :right].reshape(-1).cpu().numpy())
111
+
112
  yield BatchFeature({
113
  **self.encode_func(list_x),
114
  "input_lengths": input_lengths,
 
118
  # 重置批次计数器和Tensor内容
119
  batch_num = 0
120
  wav_tensor.zero_()
121
+ input_lengths.clear()
122
  input_seq_no.zero_()
123
 
124
  # 循环结束后,处理最后一个未满的批次
125
  if batch_num > 0:
126
+ list_x = []
127
+ for xi in range(batch_num):
128
+ _, right = input_lengths[xi]
129
+ if right == right_boundary and torch.any(wav_tensor[xi, :, right:] != 0):
130
+ list_x.append(wav_tensor[xi].reshape(-1).cpu().numpy())
131
+ else:
132
+ list_x.append(wav_tensor[xi, :, :right].reshape(-1).cpu().numpy())
133
  yield BatchFeature({
134
  **self.encode_func(list_x),
135
  "input_lengths": input_lengths,
136
  "chunk_seq_no": input_seq_no[:batch_num].clone(),
137
  })
138
 
139
+ def chunk_and_pad_view(self, tensor, seq_no):
140
+ x = tensor[0:1, :].unsqueeze(0)
141
+
142
+ stride = self.duration_size
143
+ kernel = self.chunk_size
144
+ B, C, L = x.shape
145
+
146
+ num_chunks = max(0, math.ceil((L - kernel) / stride)) + 1
147
+ target_len = (num_chunks - 1) * stride + kernel
148
+ padding_size = max(0, target_len - L)
149
+ x_padded = F.pad(x, (0, padding_size), "constant", 0)
150
+ output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
151
+
152
+ output_lengths = self.get_windows_boundaries(num_chunks, L)
153
+ output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
154
+ return output_tensor, output_lengths, output_seq_no
155
+
156
+ def get_left_boundary(self):
157
+ if self.overlap_side == "right":
158
+ return 0
159
+ else:
160
+ return int(self.overlap_size / 2)
161
+
162
+ def get_right_boundary(self):
163
+ if self.overlap_side == "right":
164
+ return self.duration_size
165
+ else:
166
+ return self.chunk_size - int(self.overlap_size / 2)
167
+
168
+ def get_windows_boundaries(self, num_chunks, seq_len):
169
+ left_boundary = self.get_left_boundary()
170
+ right_boundary = self.get_right_boundary()
171
+
172
+ output_lengths = [(left_boundary, right_boundary) for _ in range(num_chunks)]
173
+ output_lengths[0] = (0, output_lengths[0][1])
174
+ output_lengths[-1] = (output_lengths[-1][0], seq_len - self.duration_size * (num_chunks-1))
175
+ return output_lengths
176
+
177
 
178
  class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
179
  def __init__(
 
190
  dither=0.0,
191
  return_attention_mask=False,
192
  max_frequency=None,
193
+ batch_size=8,
194
+ overlap_side="both",
195
  **kwargs,
196
  ):
197
  super().__init__(
 
219
  norm="slaney",
220
  mel_scale="slaney",
221
  )
222
+ self.overlap_side = overlap_side
223
 
224
  def __call__(
225
  self,
 
243
 
244
  return ExtractorIterator(
245
  raw_speech,
246
+ batch_size=self.batch_size if self.batch_size else len(raw_speech),
247
  chunk_length=self.chunk_length,
248
  overlap_seconds=overlap_seconds,
249
+ overlap_side=self.overlap_side,
250
  sampling_rate=self.sampling_rate,
251
  encode_func=partial(
252
  super().__call__,
modeling_xy_tokenizer.py CHANGED
@@ -858,6 +858,16 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
858
 
859
  return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device)
860
 
 
 
 
 
 
 
 
 
 
 
861
  @torch.inference_mode
862
  def encode(
863
  self,
@@ -896,11 +906,11 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
896
  for chunk_features in features:
897
  # Always use return_dict=True for easier access to named outputs
898
  chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
899
- valid_code_lengths = torch.clamp(chunk_features["input_lengths"], 0, features.duration_size) // self.encoder_downsample_rate
900
 
901
  # Accumulate weighted commit loss
902
  chunk_length = chunk_output.codes_lengths.sum().item()
903
- valid_chunk_length = valid_code_lengths.sum().item()
904
  if chunk_output.commit_loss is not None and valid_chunk_length > 0:
905
  commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length
906
  commit_losses.append((commit_loss.cpu(), valid_chunk_length))
@@ -908,12 +918,12 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
908
 
909
  # Group results by original sequence ID
910
  for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()):
911
- valid_code_length = valid_code_lengths[i]
912
- if valid_code_length > 0:
913
- encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i+1, :, :valid_code_length])
914
- encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i+1, :valid_code_length])
915
  # Add the valid length of this chunk to the total for this sequence
916
- encodings[seq_id]["length"] += valid_code_lengths[i].item()
917
 
918
  final_outputs = []
919
  for seq_id, seq_data in encodings.items():
 
858
 
859
  return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device)
860
 
861
+ def scale_window_size(self, boundaries, scaling_factor):
862
+ scaling_range = []
863
+ scaling_boundaries = []
864
+ for left_boundary, right_boundary in boundaries:
865
+ scaling_left_boundary = left_boundary// scaling_factor
866
+ scaling_right_boundary = right_boundary // scaling_factor
867
+ scaling_range.append(scaling_right_boundary-scaling_left_boundary)
868
+ scaling_boundaries.append(slice(scaling_left_boundary, scaling_right_boundary))
869
+ return scaling_range, scaling_boundaries
870
+
871
  @torch.inference_mode
872
  def encode(
873
  self,
 
906
  for chunk_features in features:
907
  # Always use return_dict=True for easier access to named outputs
908
  chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
909
+ valid_code_lengths, valid_code_ranges = self.scale_window_size(chunk_features["input_lengths"], self.encoder_downsample_rate)
910
 
911
  # Accumulate weighted commit loss
912
  chunk_length = chunk_output.codes_lengths.sum().item()
913
+ valid_chunk_length = sum(valid_code_lengths)
914
  if chunk_output.commit_loss is not None and valid_chunk_length > 0:
915
  commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length
916
  commit_losses.append((commit_loss.cpu(), valid_chunk_length))
 
918
 
919
  # Group results by original sequence ID
920
  for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()):
921
+ valid_code_range = valid_code_ranges[i]
922
+ if valid_code_range.stop > 0:
923
+ encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i+1, :, valid_code_range])
924
+ encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i+1, valid_code_range])
925
  # Add the valid length of this chunk to the total for this sequence
926
+ encodings[seq_id]["length"] += valid_code_lengths[i]
927
 
928
  final_outputs = []
929
  for seq_id, seq_data in encodings.items():
preprocessor_config.json CHANGED
@@ -9,5 +9,6 @@
9
  "padding_value": 0.0,
10
  "sampling_rate": 16000,
11
  "return_attention_mask": true,
12
- "return_tensors": "pt"
 
13
  }
 
9
  "padding_value": 0.0,
10
  "sampling_rate": 16000,
11
  "return_attention_mask": true,
12
+ "return_tensors": "pt",
13
+ "overlap_side": "both"
14
  }