Add middle ("both"-sided) overlap type for chunked audio feature extraction
Browse files- feature_extraction_xy_tokenizer.py +73 -36
- modeling_xy_tokenizer.py +17 -7
- preprocessor_config.json +2 -1
feature_extraction_xy_tokenizer.py
CHANGED
|
@@ -18,6 +18,7 @@ Feature extractor class for Whisper
|
|
| 18 |
import math
|
| 19 |
from functools import partial
|
| 20 |
from typing import List, Optional, Union
|
|
|
|
| 21 |
|
| 22 |
import torch
|
| 23 |
import torch.nn.functional as F
|
|
@@ -34,9 +35,10 @@ class ExtractorIterator:
|
|
| 34 |
def __init__(
|
| 35 |
self,
|
| 36 |
data,
|
| 37 |
-
batch_size=
|
| 38 |
chunk_length=30,
|
| 39 |
-
overlap_seconds=10,
|
|
|
|
| 40 |
sampling_rate=16000,
|
| 41 |
encode_func = None,
|
| 42 |
) -> None:
|
|
@@ -44,12 +46,16 @@ class ExtractorIterator:
|
|
| 44 |
self.batch_size = batch_size
|
| 45 |
self.chunk_length = chunk_length
|
| 46 |
self.overlap_seconds = overlap_seconds
|
|
|
|
| 47 |
self.sampling_rate = sampling_rate
|
| 48 |
|
| 49 |
# duration_size 是每次处理的有效音频长度
|
| 50 |
self.chunk_size = int(self.chunk_length * self.sampling_rate)
|
| 51 |
-
self.
|
| 52 |
-
self.duration_size =
|
|
|
|
|
|
|
|
|
|
| 53 |
# 注意:这里我们只处理不带重叠的块,重叠将在外部处理(如果需要)
|
| 54 |
# 或者在迭代器内部更明确地处理。为了简化,我们假设分块是基于 duration_size
|
| 55 |
|
|
@@ -66,29 +72,13 @@ class ExtractorIterator:
|
|
| 66 |
|
| 67 |
# 注意:chunk_and_pad_view 输出的块大小是 duration_size
|
| 68 |
wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
|
| 69 |
-
input_lengths =
|
| 70 |
input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
x = tensor[0:1, :].unsqueeze(0)
|
| 74 |
-
|
| 75 |
-
stride = self.duration_size
|
| 76 |
-
kernel = self.chunk_size
|
| 77 |
-
B, C, L = x.shape
|
| 78 |
-
|
| 79 |
-
num_chunks = math.ceil(L / stride)
|
| 80 |
-
target_len = (num_chunks - 1) * stride + kernel
|
| 81 |
-
padding_size = max(0, target_len - L)
|
| 82 |
-
x_padded = F.pad(x, (0, padding_size), "constant", 0)
|
| 83 |
-
output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
|
| 84 |
-
output_lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
|
| 85 |
-
for i in range(num_chunks):
|
| 86 |
-
output_lengths[i] = min(output_lengths[i], L - stride * i)
|
| 87 |
-
output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
|
| 88 |
-
return output_tensor, output_lengths, output_seq_no
|
| 89 |
|
| 90 |
for i, sample in enumerate(self.data):
|
| 91 |
-
sample_chunks, sample_lengths, sample_seq_no = chunk_and_pad_view(sample, i)
|
| 92 |
|
| 93 |
processed_in_sample = 0
|
| 94 |
while processed_in_sample < len(sample_chunks):
|
|
@@ -103,7 +93,7 @@ class ExtractorIterator:
|
|
| 103 |
|
| 104 |
# 填充数据
|
| 105 |
wav_tensor[start_idx_batch:end_idx_batch] = sample_chunks[start_idx_sample:end_idx_sample]
|
| 106 |
-
input_lengths
|
| 107 |
input_seq_no[start_idx_batch:end_idx_batch] = sample_seq_no[start_idx_sample:end_idx_sample]
|
| 108 |
|
| 109 |
# 更新计数器
|
|
@@ -112,10 +102,13 @@ class ExtractorIterator:
|
|
| 112 |
|
| 113 |
# 如果批次满了,yield 一个副本并重置
|
| 114 |
if batch_num == self.batch_size:
|
| 115 |
-
list_x = [
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
yield BatchFeature({
|
| 120 |
**self.encode_func(list_x),
|
| 121 |
"input_lengths": input_lengths,
|
|
@@ -125,21 +118,62 @@ class ExtractorIterator:
|
|
| 125 |
# 重置批次计数器和Tensor内容
|
| 126 |
batch_num = 0
|
| 127 |
wav_tensor.zero_()
|
| 128 |
-
input_lengths.
|
| 129 |
input_seq_no.zero_()
|
| 130 |
|
| 131 |
# 循环结束后,处理最后一个未满的批次
|
| 132 |
if batch_num > 0:
|
| 133 |
-
list_x = [
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
yield BatchFeature({
|
| 138 |
**self.encode_func(list_x),
|
| 139 |
"input_lengths": input_lengths,
|
| 140 |
"chunk_seq_no": input_seq_no[:batch_num].clone(),
|
| 141 |
})
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
| 145 |
def __init__(
|
|
@@ -156,7 +190,8 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
|
| 156 |
dither=0.0,
|
| 157 |
return_attention_mask=False,
|
| 158 |
max_frequency=None,
|
| 159 |
-
batch_size=
|
|
|
|
| 160 |
**kwargs,
|
| 161 |
):
|
| 162 |
super().__init__(
|
|
@@ -184,6 +219,7 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
|
| 184 |
norm="slaney",
|
| 185 |
mel_scale="slaney",
|
| 186 |
)
|
|
|
|
| 187 |
|
| 188 |
def __call__(
|
| 189 |
self,
|
|
@@ -207,9 +243,10 @@ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
|
| 207 |
|
| 208 |
return ExtractorIterator(
|
| 209 |
raw_speech,
|
| 210 |
-
batch_size=
|
| 211 |
chunk_length=self.chunk_length,
|
| 212 |
overlap_seconds=overlap_seconds,
|
|
|
|
| 213 |
sampling_rate=self.sampling_rate,
|
| 214 |
encode_func=partial(
|
| 215 |
super().__call__,
|
|
|
|
| 18 |
import math
|
| 19 |
from functools import partial
|
| 20 |
from typing import List, Optional, Union
|
| 21 |
+
from collections import deque
|
| 22 |
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
|
|
|
| 35 |
def __init__(
|
| 36 |
self,
|
| 37 |
data,
|
| 38 |
+
batch_size=8,
|
| 39 |
chunk_length=30,
|
| 40 |
+
overlap_seconds=10,
|
| 41 |
+
overlap_side="both",
|
| 42 |
sampling_rate=16000,
|
| 43 |
encode_func = None,
|
| 44 |
) -> None:
|
|
|
|
| 46 |
self.batch_size = batch_size
|
| 47 |
self.chunk_length = chunk_length
|
| 48 |
self.overlap_seconds = overlap_seconds
|
| 49 |
+
self.overlap_side = overlap_side
|
| 50 |
self.sampling_rate = sampling_rate
|
| 51 |
|
| 52 |
# duration_size 是每次处理的有效音频长度
|
| 53 |
self.chunk_size = int(self.chunk_length * self.sampling_rate)
|
| 54 |
+
self.overlap_size = int(self.overlap_seconds * self.sampling_rate)
|
| 55 |
+
self.duration_size = self.chunk_size - self.overlap_size
|
| 56 |
+
assert (
|
| 57 |
+
(overlap_side == "right") or (self.overlap_size % 2 == 0)
|
| 58 |
+
), '`overlap_seconds` must be divisible by 2 when `overlap_side` is "both".'
|
| 59 |
# 注意:这里我们只处理不带重叠的块,重叠将在外部处理(如果需要)
|
| 60 |
# 或者在迭代器内部更明确地处理。为了简化,我们假设分块是基于 duration_size
|
| 61 |
|
|
|
|
| 72 |
|
| 73 |
# 注意:chunk_and_pad_view 输出的块大小是 duration_size
|
| 74 |
wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
|
| 75 |
+
input_lengths = deque(maxlen=self.batch_size)
|
| 76 |
input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
|
| 77 |
+
|
| 78 |
+
right_boundary = self.get_right_boundary()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
for i, sample in enumerate(self.data):
|
| 81 |
+
sample_chunks, sample_lengths, sample_seq_no = self.chunk_and_pad_view(sample, i)
|
| 82 |
|
| 83 |
processed_in_sample = 0
|
| 84 |
while processed_in_sample < len(sample_chunks):
|
|
|
|
| 93 |
|
| 94 |
# 填充数据
|
| 95 |
wav_tensor[start_idx_batch:end_idx_batch] = sample_chunks[start_idx_sample:end_idx_sample]
|
| 96 |
+
input_lengths.extend(sample_lengths[start_idx_sample:end_idx_sample])
|
| 97 |
input_seq_no[start_idx_batch:end_idx_batch] = sample_seq_no[start_idx_sample:end_idx_sample]
|
| 98 |
|
| 99 |
# 更新计数器
|
|
|
|
| 102 |
|
| 103 |
# 如果批次满了,yield 一个副本并重置
|
| 104 |
if batch_num == self.batch_size:
|
| 105 |
+
list_x = []
|
| 106 |
+
for xi, (_, right) in enumerate(input_lengths):
|
| 107 |
+
if right == right_boundary and torch.any(wav_tensor[xi, :, right:] != 0):
|
| 108 |
+
list_x.append(wav_tensor[xi].reshape(-1).cpu().numpy())
|
| 109 |
+
else:
|
| 110 |
+
list_x.append(wav_tensor[xi, :, :right].reshape(-1).cpu().numpy())
|
| 111 |
+
|
| 112 |
yield BatchFeature({
|
| 113 |
**self.encode_func(list_x),
|
| 114 |
"input_lengths": input_lengths,
|
|
|
|
| 118 |
# 重置批次计数器和Tensor内容
|
| 119 |
batch_num = 0
|
| 120 |
wav_tensor.zero_()
|
| 121 |
+
input_lengths.clear()
|
| 122 |
input_seq_no.zero_()
|
| 123 |
|
| 124 |
# 循环结束后,处理最后一个未满的批次
|
| 125 |
if batch_num > 0:
|
| 126 |
+
list_x = []
|
| 127 |
+
for xi in range(batch_num):
|
| 128 |
+
_, right = input_lengths[xi]
|
| 129 |
+
if right == right_boundary and torch.any(wav_tensor[xi, :, right:] != 0):
|
| 130 |
+
list_x.append(wav_tensor[xi].reshape(-1).cpu().numpy())
|
| 131 |
+
else:
|
| 132 |
+
list_x.append(wav_tensor[xi, :, :right].reshape(-1).cpu().numpy())
|
| 133 |
yield BatchFeature({
|
| 134 |
**self.encode_func(list_x),
|
| 135 |
"input_lengths": input_lengths,
|
| 136 |
"chunk_seq_no": input_seq_no[:batch_num].clone(),
|
| 137 |
})
|
| 138 |
|
| 139 |
+
def chunk_and_pad_view(self, tensor, seq_no):
    """Split one waveform into fixed-size overlapping windows, zero-padding the tail.

    Args:
        tensor: waveform of shape (channels, samples); only the first channel is used.
        seq_no: integer id of the source sequence, replicated once per chunk.

    Returns:
        A tuple (chunks, boundaries, seq_nos): chunks has shape
        (num_chunks, 1, chunk_size); boundaries are the per-chunk
        (left, right) valid-sample bounds from get_windows_boundaries;
        seq_nos is a LongTensor of shape (num_chunks,) filled with seq_no.
    """
    # Keep only the first channel; shape becomes (1, 1, samples).
    x = tensor[0:1, :].unsqueeze(0)

    stride = self.duration_size  # hop between consecutive window starts
    kernel = self.chunk_size     # window length in samples
    seq_len = x.shape[-1]

    # Number of windows needed to cover seq_len. Pure-integer ceil division
    # (-((kernel - seq_len) // stride) == ceil((seq_len - kernel) / stride))
    # avoids float rounding on very long signals; at least one window always.
    num_chunks = max(0, -((kernel - seq_len) // stride)) + 1
    target_len = (num_chunks - 1) * stride + kernel
    padding_size = max(0, target_len - seq_len)
    x_padded = F.pad(x, (0, padding_size), "constant", 0)
    # unfold -> (1, 1, num_chunks, kernel); squeeze/transpose -> (num_chunks, 1, kernel).
    output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)

    output_lengths = self.get_windows_boundaries(num_chunks, seq_len)
    output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
    return output_tensor, output_lengths, output_seq_no
|
| 155 |
+
|
| 156 |
+
def get_left_boundary(self):
|
| 157 |
+
if self.overlap_side == "right":
|
| 158 |
+
return 0
|
| 159 |
+
else:
|
| 160 |
+
return int(self.overlap_size / 2)
|
| 161 |
+
|
| 162 |
+
def get_right_boundary(self):
|
| 163 |
+
if self.overlap_side == "right":
|
| 164 |
+
return self.duration_size
|
| 165 |
+
else:
|
| 166 |
+
return self.chunk_size - int(self.overlap_size / 2)
|
| 167 |
+
|
| 168 |
+
def get_windows_boundaries(self, num_chunks, seq_len):
|
| 169 |
+
left_boundary = self.get_left_boundary()
|
| 170 |
+
right_boundary = self.get_right_boundary()
|
| 171 |
+
|
| 172 |
+
output_lengths = [(left_boundary, right_boundary) for _ in range(num_chunks)]
|
| 173 |
+
output_lengths[0] = (0, output_lengths[0][1])
|
| 174 |
+
output_lengths[-1] = (output_lengths[-1][0], seq_len - self.duration_size * (num_chunks-1))
|
| 175 |
+
return output_lengths
|
| 176 |
+
|
| 177 |
|
| 178 |
class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
|
| 179 |
def __init__(
|
|
|
|
| 190 |
dither=0.0,
|
| 191 |
return_attention_mask=False,
|
| 192 |
max_frequency=None,
|
| 193 |
+
batch_size=8,
|
| 194 |
+
overlap_side="both",
|
| 195 |
**kwargs,
|
| 196 |
):
|
| 197 |
super().__init__(
|
|
|
|
| 219 |
norm="slaney",
|
| 220 |
mel_scale="slaney",
|
| 221 |
)
|
| 222 |
+
self.overlap_side = overlap_side
|
| 223 |
|
| 224 |
def __call__(
|
| 225 |
self,
|
|
|
|
| 243 |
|
| 244 |
return ExtractorIterator(
|
| 245 |
raw_speech,
|
| 246 |
+
batch_size=self.batch_size if self.batch_size else len(raw_speech),
|
| 247 |
chunk_length=self.chunk_length,
|
| 248 |
overlap_seconds=overlap_seconds,
|
| 249 |
+
overlap_side=self.overlap_side,
|
| 250 |
sampling_rate=self.sampling_rate,
|
| 251 |
encode_func=partial(
|
| 252 |
super().__call__,
|
modeling_xy_tokenizer.py
CHANGED
|
@@ -858,6 +858,16 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
|
|
| 858 |
|
| 859 |
return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device)
|
| 860 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
@torch.inference_mode
|
| 862 |
def encode(
|
| 863 |
self,
|
|
@@ -896,11 +906,11 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
|
|
| 896 |
for chunk_features in features:
|
| 897 |
# Always use return_dict=True for easier access to named outputs
|
| 898 |
chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
|
| 899 |
-
valid_code_lengths =
|
| 900 |
|
| 901 |
# Accumulate weighted commit loss
|
| 902 |
chunk_length = chunk_output.codes_lengths.sum().item()
|
| 903 |
-
valid_chunk_length =
|
| 904 |
if chunk_output.commit_loss is not None and valid_chunk_length > 0:
|
| 905 |
commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length
|
| 906 |
commit_losses.append((commit_loss.cpu(), valid_chunk_length))
|
|
@@ -908,12 +918,12 @@ class XYTokenizerModel(XYTokenizerPreTrainedModel):
|
|
| 908 |
|
| 909 |
# Group results by original sequence ID
|
| 910 |
for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()):
|
| 911 |
-
|
| 912 |
-
if
|
| 913 |
-
encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i+1, :,
|
| 914 |
-
encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i+1,
|
| 915 |
# Add the valid length of this chunk to the total for this sequence
|
| 916 |
-
encodings[seq_id]["length"] += valid_code_lengths[i]
|
| 917 |
|
| 918 |
final_outputs = []
|
| 919 |
for seq_id, seq_data in encodings.items():
|
|
|
|
| 858 |
|
| 859 |
return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device)
|
| 860 |
|
| 861 |
+
def scale_window_size(self, boundaries, scaling_factor):
    """Map sample-level (left, right) boundaries to code-level lengths and slices.

    Each boundary is floor-divided by scaling_factor (the encoder downsample
    rate). Returns the list of resulting window lengths and the matching
    slice objects, in input order.
    """
    scaled = [(lo // scaling_factor, hi // scaling_factor) for lo, hi in boundaries]
    scaling_range = [hi - lo for lo, hi in scaled]
    scaling_boundaries = [slice(lo, hi) for lo, hi in scaled]
    return scaling_range, scaling_boundaries
|
| 870 |
+
|
| 871 |
@torch.inference_mode
|
| 872 |
def encode(
|
| 873 |
self,
|
|
|
|
| 906 |
for chunk_features in features:
|
| 907 |
# Always use return_dict=True for easier access to named outputs
|
| 908 |
chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
|
| 909 |
+
valid_code_lengths, valid_code_ranges = self.scale_window_size(chunk_features["input_lengths"], self.encoder_downsample_rate)
|
| 910 |
|
| 911 |
# Accumulate weighted commit loss
|
| 912 |
chunk_length = chunk_output.codes_lengths.sum().item()
|
| 913 |
+
valid_chunk_length = sum(valid_code_lengths)
|
| 914 |
if chunk_output.commit_loss is not None and valid_chunk_length > 0:
|
| 915 |
commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length
|
| 916 |
commit_losses.append((commit_loss.cpu(), valid_chunk_length))
|
|
|
|
| 918 |
|
| 919 |
# Group results by original sequence ID
|
| 920 |
for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()):
|
| 921 |
+
valid_code_range = valid_code_ranges[i]
|
| 922 |
+
if valid_code_range.stop > 0:
|
| 923 |
+
encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i+1, :, valid_code_range])
|
| 924 |
+
encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i+1, valid_code_range])
|
| 925 |
# Add the valid length of this chunk to the total for this sequence
|
| 926 |
+
encodings[seq_id]["length"] += valid_code_lengths[i]
|
| 927 |
|
| 928 |
final_outputs = []
|
| 929 |
for seq_id, seq_data in encodings.items():
|
preprocessor_config.json
CHANGED
|
@@ -9,5 +9,6 @@
|
|
| 9 |
"padding_value": 0.0,
|
| 10 |
"sampling_rate": 16000,
|
| 11 |
"return_attention_mask": true,
|
| 12 |
-
"return_tensors": "pt"
|
|
|
|
| 13 |
}
|
|
|
|
| 9 |
"padding_value": 0.0,
|
| 10 |
"sampling_rate": 16000,
|
| 11 |
"return_attention_mask": true,
|
| 12 |
+
"return_tensors": "pt",
|
| 13 |
+
"overlap_side": "both"
|
| 14 |
}
|