# XY_Tokenizer / feature_extraction_xy_tokenizer.py
# Source page metadata: author MCplayer, commit 0b4c806 ("pre-release version"), 8.7 kB.
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extractor for XY-Tokenizer: Whisper-style log-mel features computed over overlapping audio chunks.
"""
import math
from functools import partial
from typing import List, Optional, Union
import torch
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
from transformers.audio_utils import mel_filter_bank
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
class ExtractorIterator:
    """Iterable that cuts waveforms into overlapping windows and encodes them in batches.

    Every waveform in ``data`` is cut into windows of ``chunk_length`` seconds whose
    start positions advance by ``chunk_length - overlap_seconds`` seconds; the tail of
    the last window is zero-padded up to the full window size. Windows are packed, in
    order, into batches of at most ``batch_size`` windows (one batch may mix windows
    from consecutive waveforms). For each batch, ``encode_func`` is called with a list
    of 1-D numpy arrays (each window trimmed to its recorded length), and a
    ``BatchFeature`` is yielded that holds the encoder's output plus
    ``"chunk_seq_no"``: a long tensor giving, per window, the index of its source
    waveform in ``data``.

    Each ``iter()`` call re-enumerates ``data``, so iteration can be repeated when
    ``data`` is a re-iterable container such as a list.
    """

    def __init__(
        self,
        data,
        batch_size=1,
        chunk_length=30,
        overlap_seconds=10,
        sampling_rate=16000,
        encode_func=None,
    ) -> None:
        """Store and validate the chunking configuration.

        Args:
            data: Iterable of waveforms. Each item may be a 1-D ``(samples,)`` or a
                2-D tensor/array; for 2-D input only row 0 is used (presumably the
                first channel of a ``(channels, samples)`` waveform).
            batch_size: Maximum number of windows per yielded batch.
            chunk_length: Window length in seconds.
            overlap_seconds: Overlap between consecutive windows in seconds.
            sampling_rate: Sample rate of ``data`` in Hz, used to convert seconds
                into sample counts.
            encode_func: Callable mapping a list of 1-D numpy arrays to a mapping of
                features; its items are merged into each yielded ``BatchFeature``.

        Raises:
            TypeError: If ``encode_func`` is not callable.
            ValueError: If the hop ``chunk_length - overlap_seconds`` is not a
                positive number of samples, or ``overlap_seconds`` is negative.
        """
        # Real exceptions instead of ``assert``: assertions vanish under ``python -O``.
        if not callable(encode_func):
            raise TypeError(f"encode_func must be callable, got {type(encode_func).__name__}")
        self.data = data
        self.batch_size = batch_size
        self.chunk_length = chunk_length
        self.overlap_seconds = overlap_seconds
        self.sampling_rate = sampling_rate
        # chunk_size: window length in samples.
        self.chunk_size = int(self.chunk_length * self.sampling_rate)
        # duration_size: hop between window starts in samples (window minus overlap).
        self.duration_seconds = self.chunk_length - self.overlap_seconds
        self.duration_size = int(self.duration_seconds * self.sampling_rate)
        # A non-positive hop would divide by zero / never advance; a hop larger than
        # the window (negative overlap) would leave gaps and mismatch chunk counts.
        if not 0 < self.duration_size <= self.chunk_size:
            raise ValueError(
                "overlap_seconds must satisfy 0 <= overlap_seconds < chunk_length "
                f"(got chunk_length={chunk_length}, overlap_seconds={overlap_seconds})"
            )
        self.encode_func = encode_func

    def _chunk_sample(self, waveform, seq_no):
        """Cut one waveform into overlapping windows of ``chunk_size`` samples.

        Args:
            waveform: 1-D ``(samples,)`` or 2-D tensor/array; for 2-D input only
                row 0 is used.
            seq_no: Index of ``waveform`` in ``data``; copied into every window.

        Returns:
            Tuple ``(chunks, lengths, seq_nos)`` where ``chunks`` is a
            ``(num_chunks, 1, chunk_size)`` tensor, ``lengths`` is a
            ``(num_chunks,)`` long tensor of recorded valid lengths, and ``seq_nos``
            is a ``(num_chunks,)`` long tensor filled with ``seq_no``;
            ``num_chunks = ceil(samples / duration_size)``.

        Raises:
            ValueError: If the waveform is empty or has more than two dimensions.
        """
        x = torch.as_tensor(waveform)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() != 2:
            raise ValueError(f"expected a 1-D or 2-D waveform, got shape {tuple(x.shape)}")
        length = x.shape[1]
        if length == 0:
            raise ValueError("cannot chunk an empty waveform")

        stride = self.duration_size
        kernel = self.chunk_size
        num_chunks = math.ceil(length / stride)
        # Pad so the last window, starting at (num_chunks - 1) * stride, is complete.
        target_len = (num_chunks - 1) * stride + kernel
        padding_size = max(0, target_len - length)
        # Keep row 0 only and add a leading batch dim: (1, 1, length).
        x_padded = F.pad(x[0:1, :].unsqueeze(0), (0, padding_size), "constant", 0)
        # unfold -> (1, 1, num_chunks, kernel); reorder to (num_chunks, 1, kernel).
        chunks = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)

        # Every window records ``kernel`` valid samples except the last, which records
        # its unpadded length. NOTE(review): when the tail is shorter than the overlap,
        # the second-to-last window also runs past the end of the audio but still
        # records ``kernel``; left unchanged because downstream overlap stitching may
        # rely on it — confirm before changing.
        lengths = torch.full((num_chunks,), kernel, dtype=torch.long)
        if padding_size > 0:
            lengths[-1] = kernel - padding_size
        seq_nos = torch.full((num_chunks,), seq_no, dtype=torch.long)
        return chunks, lengths, seq_nos

    @staticmethod
    def _trim_batch(wav_tensor, input_lengths, count):
        """Return the first ``count`` buffered windows as 1-D numpy arrays cut to their recorded length.

        The arrays are views of ``wav_tensor`` memory (``Tensor.numpy`` shares storage).
        """
        return [
            wav_tensor[xi, :, :x_len].reshape(-1).cpu().numpy()
            for xi, x_len in enumerate(input_lengths[:count].tolist())
        ]

    def _encode_batch(self, wav_tensor, input_lengths, input_seq_no, count):
        """Encode the first ``count`` buffered windows and wrap the result in a ``BatchFeature``."""
        list_x = self._trim_batch(wav_tensor, input_lengths, count)
        return BatchFeature({
            **self.encode_func(list_x),
            "chunk_seq_no": input_seq_no[:count].clone(),
        })

    def __iter__(self):
        """Yield one ``BatchFeature`` per full batch of windows, then one for any remainder.

        Windows are copied into preallocated buffers and flushed through
        ``encode_func`` whenever ``batch_size`` windows have accumulated. The final,
        possibly partial, batch contains only the slots actually filled, so the
        encoded entries line up one-to-one with ``chunk_seq_no``.

        Raises:
            ValueError: If ``batch_size`` is 0 while ``data`` is non-empty (the
                packing loop could never make progress). A negative ``batch_size``
                fails earlier, when the buffers are allocated.
        """
        batch_num = 0  # number of buffer slots currently filled
        wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
        input_lengths = torch.zeros(self.batch_size, dtype=torch.long)
        input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)

        for seq_no, sample in enumerate(self.data):
            if self.batch_size < 1:
                raise ValueError(f"batch_size must be >= 1 to process data, got {self.batch_size}")
            sample_chunks, sample_lengths, sample_seq_no = self._chunk_sample(sample, seq_no)
            num_chunks = sample_chunks.shape[0]
            processed = 0
            while processed < num_chunks:
                # Copy as many windows as still fit into the buffer.
                take = min(self.batch_size - batch_num, num_chunks - processed)
                src = slice(processed, processed + take)
                dst = slice(batch_num, batch_num + take)
                wav_tensor[dst] = sample_chunks[src]
                input_lengths[dst] = sample_lengths[src]
                input_seq_no[dst] = sample_seq_no[src]
                batch_num += take
                processed += take

                if batch_num == self.batch_size:
                    yield self._encode_batch(wav_tensor, input_lengths, input_seq_no, batch_num)
                    # Reset the buffers for the next batch.
                    batch_num = 0
                    wav_tensor.zero_()
                    input_lengths.zero_()
                    input_seq_no.zero_()

        # Flush the trailing partial batch (filled slots only).
        if batch_num > 0:
            yield self._encode_batch(wav_tensor, input_lengths, input_seq_no, batch_num)
class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
    """Whisper-style log-mel feature extractor that processes audio in overlapping chunks.

    Compared with ``WhisperFeatureExtractor``, it differs in two ways:

    * the upper edge of the mel filter bank is configurable through ``max_frequency``;
    * ``__call__`` returns an ``ExtractorIterator`` instead of a single feature batch,
      so audio is split into ``chunk_length``-second windows and encoded lazily,
      batch by batch.
    """

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
        padding_value=0.0,
        dither=0.0,
        return_attention_mask=False,
        max_frequency=None,
        batch_size=None,
        **kwargs,
    ):
        """Configure the base Whisper extractor, then rebuild the mel filters.

        Args:
            max_frequency: Upper edge (Hz) of the mel filter bank; ``None`` means
                ``sampling_rate / 2``.
            batch_size: Number of chunks per encoded batch; ``None`` means "one
                per input waveform" (decided at call time).
            Remaining arguments and ``**kwargs`` are forwarded to
            ``WhisperFeatureExtractor.__init__`` unchanged.
        """
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            chunk_length=chunk_length,
            n_fft=n_fft,
            padding_value=padding_value,
            dither=dither,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )
        # Fall back to the Nyquist frequency when no explicit upper edge is given.
        self.max_frequency = sampling_rate / 2 if max_frequency is None else max_frequency
        self.batch_size = batch_size
        # Replace any mel_filters set by the parent so the bank honours max_frequency.
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=n_fft // 2 + 1,
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=self.max_frequency,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def __call__(
        self,
        raw_speech: Union[torch.Tensor, List[torch.Tensor]],
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = None,
        padding: Optional[str] = "max_length",
        max_length: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        do_normalize: Optional[bool] = None,
        device: Optional[str] = "cpu",
        return_token_timestamps: Optional[bool] = None,
        overlap_seconds: int = 10,
        **kwargs,
    ) -> ExtractorIterator:
        """Return a lazy iterator of feature batches for ``raw_speech``.

        Args:
            raw_speech: One waveform or a list of waveforms; anything that is not a
                ``list`` is treated as a single waveform.
            overlap_seconds: Overlap between consecutive chunks, in seconds.
            All other named arguments and ``**kwargs`` are bound now and forwarded
            unchanged to ``WhisperFeatureExtractor.__call__`` for every batch.

        Returns:
            An ``ExtractorIterator`` whose batch size is ``self.batch_size``, or the
            number of waveforms when ``self.batch_size`` is ``None``.
        """
        waveforms = raw_speech if isinstance(raw_speech, list) else [raw_speech]
        if self.batch_size is None:
            effective_batch_size = len(waveforms)
        else:
            effective_batch_size = self.batch_size
        # Freeze every Whisper option here; the iterator supplies only the audio list.
        encode = partial(
            super().__call__,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
            padding=padding,
            max_length=max_length,
            sampling_rate=sampling_rate,
            do_normalize=do_normalize,
            device=device,
            return_token_timestamps=return_token_timestamps,
            **kwargs,
        )
        return ExtractorIterator(
            waveforms,
            batch_size=effective_batch_size,
            chunk_length=self.chunk_length,
            overlap_seconds=overlap_seconds,
            sampling_rate=self.sampling_rate,
            encode_func=encode,
        )