Spaces:
Running
on
L4
| import os | |
| import sys | |
| sys.path.append("../../..") | |
| import io | |
| import glob | |
| import math | |
| import tarfile | |
| import torch | |
| import torchaudio | |
| import safetensors | |
| from .configuration_whisper import WhisperVQConfig | |
| from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration | |
| from transformers import WhisperFeatureExtractor, WhisperTokenizerFast | |
| import asyncio | |
| from .batch_processor import AsyncBatchEngine # 修改为你的路径 | |
| from typing import List, Union, Tuple, Literal, Optional | |
class SpeechTokenExtractor:
    """Extract discrete speech tokens from audio with a WhisperVQ encoder.

    Pipeline: load audio (resampling to 16 kHz when needed) -> compute
    padded log-mel features on ``device`` -> run the quantizing encoder ->
    return per-utterance token lists with padded positions stripped.

    Both a synchronous batch API (`extract`) and an async single-item API
    (`extract_async`, backed by an AsyncBatchEngine micro-batcher) are
    provided.
    """

    def __init__(
        self,
        model: WhisperVQEncoder,
        feature_extractor: WhisperFeatureExtractor,
        device: str = "cuda",
        batch_size: int = 32,
        wait_timeout: float = 0.01,
    ):
        """
        Args:
            model: quantizing Whisper encoder; put in eval mode on ``device``.
            feature_extractor: Whisper log-mel feature extractor (16 kHz).
            device: torch device string ("cpu", "cuda", "mps", ...).
            batch_size: max utterances per batch in the async engine.
            wait_timeout: seconds the async engine waits to fill a batch.
        """
        self.model = model.eval().to(device)
        self.feature_extractor = feature_extractor
        self.device = device
        self.wait_timeout = wait_timeout
        # Run features in the model's parameter dtype (fp16/bf16/fp32).
        self.dtype = next(model.parameters()).dtype
        # Frame/sample stride, used for pad alignment and mask downsampling:
        # the encoder downsamples mel frames by conv1 * conv2 strides, then
        # by an optional pooling layer.
        self.pooling_kernel_size = getattr(model.config, "pooling_kernel_size", 1)
        self.frame_stride = (
            model.conv1.stride[0]
            * model.conv2.stride[0]
            * self.pooling_kernel_size
        )
        self.sample_stride = self.frame_stride * feature_extractor.hop_length
        # Resample transforms cached per source sample rate (kept on device).
        self._resamplers: dict[int, torchaudio.transforms.Resample] = {}
        self._batch_processor = AsyncBatchEngine(
            processing_function=self._batch_extract_async,
            batch_size=batch_size,
            wait_timeout=wait_timeout,
        )

    # -------- I/O & resampling: keep tensors on device --------
    def _load_audio(self, utt: Union[str, torch.Tensor]) -> torch.Tensor:
        """Load one utterance -> 1D float32 waveform at 16 kHz on self.device.

        Tensor inputs are assumed to already be mono 16 kHz waveforms
        (NOTE(review): not verifiable here — confirm with callers).
        File paths are loaded with torchaudio, downmixed to mono, and
        resampled to 16 kHz when necessary.
        """
        if isinstance(utt, torch.Tensor):
            audio = utt.to(self.device, non_blocking=True)
        else:
            audio, sr = torchaudio.load(utt)  # CPU tensor [channels, T]
            if audio.ndim > 1 and audio.size(0) > 1:  # downmix to mono
                audio = audio.mean(dim=0, keepdim=True)
            if sr != 16000:
                # Bug fix: the resampler cache was built in __init__ but never
                # used, so non-16 kHz files were silently fed at the wrong rate.
                resampler = self._resamplers.get(sr)
                if resampler is None:
                    resampler = torchaudio.transforms.Resample(
                        orig_freq=sr, new_freq=16000
                    ).to(self.device)
                    self._resamplers[sr] = resampler
                audio = resampler(audio.to(self.device, non_blocking=True))
        audio = audio.squeeze(0).to(torch.float32).to(self.device, non_blocking=True)
        return audio  # [T] on device

    # -------- feature extraction on device --------
    def _extract_features_gpu(self, audios: List[torch.Tensor]) -> dict:
        """Compute padded log-mel features + attention mask for a batch.

        1) tensors are converted to CPU float32 numpy (FE requirement);
        2) the extractor is asked to place output tensors on self.device;
        3) input_features are cast to the model dtype when it is not fp32.
        """
        np_audios = [a.detach().cpu().numpy().astype("float32") for a in audios]
        feats = self.feature_extractor(
            np_audios,
            sampling_rate=16000,
            return_attention_mask=True,
            return_tensors="pt",
            device=self.device,
            padding="longest",
            # Pad so the mel length divides evenly through conv + pooling strides.
            pad_to_multiple_of=self.sample_stride,
        )
        # Safety net: move anything the extractor still left on CPU.
        feats = {
            k: (v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v)
            for k, v in feats.items()
        }
        # Match the encoder dtype (generalized from fp16-only to any non-fp32
        # dtype, e.g. bf16); the attention mask stays integral.
        if self.dtype != torch.float32 and "input_features" in feats:
            feats["input_features"] = feats["input_features"].to(self.dtype)
        return feats

    def _forward(self, feats: dict) -> List[List[int]]:
        """Run the encoder and strip padded positions from the token ids."""
        outputs = self.model(**feats)
        tokens = outputs.quantized_token_ids
        # Downsample the frame-level mask the same way the encoder does:
        # conv1/conv2 strides, then the pooling kernel.
        conv_stride = self.model.conv1.stride[0] * self.model.conv2.stride[0]
        attn = feats["attention_mask"][:, ::conv_stride][:, :: self.pooling_kernel_size]
        return [t[m.bool()].tolist() for t, m in zip(tokens, attn)]

    # -------- synchronous batch API --------
    def extract(self, utts: List[Union[str, torch.Tensor]]) -> List[List[int]]:
        """Extract token lists for a batch of utterances.

        No 30 s chunking and no micro-batching here; the flow is simply:
        load/resample -> on-device feature extraction -> forward -> unpad.
        """
        audios = [self._load_audio(u) for u in utts]  # list[Tensor(T)] on device
        with torch.inference_mode():
            feats = self._extract_features_gpu(audios)  # on device
            return self._forward(feats)

    # -------- async API (preserves the batch engine's return protocol) ----
    async def _batch_extract_async(self, utts: List[Union[str, torch.Tensor]]):
        """Batch callback for AsyncBatchEngine; runs `extract` off-loop."""
        tokens_list = await asyncio.to_thread(self.extract, utts)
        return [{"tokens": t} for t in tokens_list]

    async def extract_async(self, utt: Union[str, torch.Tensor]):
        """Extract tokens for a single utterance via the micro-batcher."""
        result = await self._batch_processor.add_request(single_input=utt)
        # NOTE(review): assumes the engine wraps each item under "feature";
        # if that key is absent this raises AttributeError on None — confirm
        # against AsyncBatchEngine's result protocol.
        feature = result.get("feature")
        return feature.get("tokens")