import sys

sys.path.append("../../..")

import asyncio
from typing import List, Literal, Union

import torch
import torchaudio
from transformers import WhisperFeatureExtractor

from .batch_processor import AsyncBatchEngine  # adjust to your module path
from .modeling_whisper import WhisperVQEncoder


class SpeechTokenExtractor:
    def __init__(
        self,
        model: WhisperVQEncoder,
        feature_extractor: WhisperFeatureExtractor,
        device: Literal["cpu", "cuda", "mps"] | str = "cuda",
        batch_size: int = 32,
        wait_timeout: float = 0.01,
    ):
        self.model = model.eval().to(device)
        self.feature_extractor = feature_extractor
        self.device = device
        self.wait_timeout = wait_timeout
        self.dtype = next(model.parameters()).dtype

        # Frame/sample strides (used for pad alignment & mask downsampling)
        self.pooling_kernel_size = getattr(model.config, "pooling_kernel_size", 1)
        self.frame_stride = (
            model.conv1.stride[0] *
            model.conv2.stride[0] *
            self.pooling_kernel_size
        )
        self.sample_stride = self.frame_stride * feature_extractor.hop_length
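        # Worked example (illustrative values, assuming the stock Whisper
        # geometry: conv1.stride=1, conv2.stride=2, hop_length=160, and a
        # pooling_kernel_size of 4):
        #   frame_stride  = 1 * 2 * 4 = 8 encoder frames per token
        #   sample_stride = 8 * 160   = 1280 samples = 80 ms at 16 kHz
        # Padding waveforms to a multiple of sample_stride keeps the token
        # grid and the downsampled attention mask exactly aligned.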

        # Resampler cache, keyed by source sample rate (modules live on device)
        self._resamplers: dict[int, torchaudio.transforms.Resample] = {}

        self._batch_processor = AsyncBatchEngine(
            processing_function=self._batch_extract_async,
            batch_size=batch_size,
            wait_timeout=wait_timeout,
        )

    # -------- I/O & resampling: keep everything on device --------
    def _load_audio(self, utt: Union[str, torch.Tensor]) -> torch.Tensor:
        """Load one utterance -> 1D float32 waveform on self.device at 16 kHz.

        Tensor inputs are assumed to already be 16 kHz mono waveforms.
        """
        if isinstance(utt, torch.Tensor):
            audio = utt.to(self.device, non_blocking=True)
        else:
            audio, sr = torchaudio.load(utt)          # loads on CPU
            if audio.ndim > 1 and audio.size(0) > 1:  # downmix to mono
                audio = audio.mean(dim=0, keepdim=True)
            audio = audio.squeeze(0).to(torch.float32).to(self.device, non_blocking=True)
            if sr != 16000:                           # resample via the cached kernels
                if sr not in self._resamplers:
                    self._resamplers[sr] = torchaudio.transforms.Resample(sr, 16000).to(self.device)
                audio = self._resamplers[sr](audio)

        return audio  # [T] on device

    # -------- Feature extraction on the GPU --------
    def _extract_features_gpu(self, audios: List[torch.Tensor]) -> dict:
        """
        1) Move every input to CPU numpy float32 (what the feature extractor expects).
        2) Call the FE with device=self.device so its output tensors land on the GPU.
        3) If the model runs in fp16, cast only input_features to half (leave the mask alone).
        """
        # 1) CUDA/CPU tensor -> CPU numpy
        np_audios = [a.detach().cpu().numpy().astype("float32") for a in audios]

        # 2) Pad to a multiple of the model's sample stride so the token grid
        #    and the attention mask stay aligned end to end
        feats = self.feature_extractor(
            np_audios,
            sampling_rate=16000,
            return_attention_mask=True,
            return_tensors="pt",
            device=self.device,                  # torch STFT runs on this device
            padding="longest",
            pad_to_multiple_of=self.sample_stride,
        )

        feats = {
            k: (v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v)
            for k, v in feats.items()
        }

        # 3) Match the model's half precision (input_features only)
        if self.dtype == torch.float16 and "input_features" in feats:
            feats["input_features"] = feats["input_features"].half()

        return feats


    def _forward(self, feats: dict) -> List[List[int]]:
        outputs = self.model(**feats)
        tokens = outputs.quantized_token_ids
        # Downsample the mask to the token grid: conv downsampling × pooling
        attn = feats["attention_mask"][
            :, :: self.model.conv1.stride[0] * self.model.conv2.stride[0]
        ][:, :: self.pooling_kernel_size]
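        # Shape sketch (illustrative): a 30 s clip yields a [B, 3000] mask of
        # 10 ms frames; with conv strides 1*2 and a pooling kernel of 4, attn
        # comes out [B, 375] -- one entry per quantized token (80 ms each).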
        return [t[m.bool()].tolist() for t, m in zip(tokens, attn)]

    # -------- Synchronous batch API --------
    def extract(self, utts: List[Union[str, torch.Tensor]]) -> List[List[int]]:
        """
        不做 30s 分片,也不做 microbatch。
        直接:加载/重采样 -> GPU 特征提取 -> 前向 -> 对齐输出。
        """
        audios = [self._load_audio(u) for u in utts]          # list[Tensor(T)] on device
        with torch.inference_mode():
            feats = self._extract_features_gpu(audios)        # on device
            return self._forward(feats)

    # -------- Async batch API (keeps the engine's return protocol) --------
    async def _batch_extract_async(self, utts: List[Union[str, torch.Tensor]]):
        tokens_list = await asyncio.to_thread(self.extract, utts)
        return [{"tokens": t} for t in tokens_list]

    async def extract_async(self, utt: Union[str, torch.Tensor]):
        result = await self._batch_processor.add_request(single_input=utt)
        feature = result.get("feature")
        return feature.get("tokens")
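

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module). The checkpoint path
# below is a placeholder, and loading via .from_pretrained assumes
# WhisperVQEncoder subclasses transformers.PreTrainedModel; adapt both to
# however the encoder is materialized in your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = WhisperVQEncoder.from_pretrained("path/to/whisper-vq-checkpoint")
    feature_extractor = WhisperFeatureExtractor.from_pretrained("path/to/whisper-vq-checkpoint")
    extractor = SpeechTokenExtractor(model, feature_extractor, device="cuda")

    # Synchronous: one batch of files and/or 16 kHz waveform tensors in,
    # one list of token id sequences out.
    token_lists = extractor.extract(["sample_a.wav", "sample_b.wav"])
    print([len(t) for t in token_lists])

    # Asynchronous: concurrent single requests are coalesced into batches
    # by the AsyncBatchEngine before hitting the GPU.
    async def demo():
        results = await asyncio.gather(
            extractor.extract_async("sample_a.wav"),
            extractor.extract_async("sample_b.wav"),
        )
        print([len(t) for t in results])

    asyncio.run(demo())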