File size: 9,131 Bytes
528efee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# -*- coding: utf-8 -*-
# Time      :2025/3/29 10:30
# Author    :Hui Huang
import os
from typing import Literal, Optional, Tuple, Dict, Any, List, Union

import torch
import torchaudio
import torchaudio.transforms as TT
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
import numpy as np
from loguru import logger
from pathlib import Path

# ----------------- These modules are assumed to live under your project path -----------------
from .utils.file import load_config
from .utils.audio import load_audio
from .models.bicodec import BiCodec
from .base_model import SparkBaseModel
from .batch_processor import AsyncBatchEngine
# ---------------------------------------------------------------

__all__ = ["SparkTokenizer"]


class SparkTokenizer:
    def __init__(
            self,
            model_path: str,
            device: Literal["cpu", "cuda", "mps"] | str = "cuda",
            attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = "eager",
            batch_size: int = 32,
            wait_timeout: float = 0.01,
    ):
        self.device = torch.device(device)
        self.model_dir = Path(model_path)

        # 1. 加载配置
        self.config = load_config(self.model_dir / "config.yaml")
        self.device_type = "cuda" if "cuda" in str(device) else "cpu"
        self.dtype = torch.float16 if self.device_type == "cuda" else torch.float32
        self.target_sample_rate = self.config.get("sample_rate", 16000)

        # 2. 加载模型
        wav2vec_path = self.model_dir / "wav2vec2-large-xlsr-53"
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
        self.feature_extractor = Wav2Vec2Model.from_pretrained(
            wav2vec_path,
            attn_implementation=attn_implementation,
            torch_dtype=self.dtype 
        )
        self.feature_extractor.config.output_hidden_states = True
        self.feature_extractor.to(self.device)
        self.feature_extractor.eval()

        # BiCodec model
        self.model = (
            BiCodec.load_from_checkpoint(str(self.model_dir)).to(self.device).half()
        )
        self.model.eval()

        # 异步处理引擎
        self._batch_processor = AsyncBatchEngine(
            processing_function=self.batch_tokenize_async,
            batch_size=batch_size,
            wait_timeout=wait_timeout
        )

    def _to_ndarray(self, audio_input: Union[str, Path, torch.Tensor]) -> np.ndarray:
        """
        将输入(路径或Tensor)统一转换为指定采样率的 numpy 数组。
        """
        if isinstance(audio_input, (str, Path)):
            # 如果是路径,直接使用原有的 load_audio
            wav = load_audio(
                str(audio_input),
                sampling_rate=self.target_sample_rate,
                volume_normalize=self.config.get("volume_normalize", True),
            )
        elif isinstance(audio_input, torch.Tensor):
            # 如果是 Tensor
            wav = audio_input.detach().cpu().float()

            # 处理通道: [C, T] -> [T]
            if wav.ndim > 1:
                wav = torch.mean(wav, dim=0)

            # 这里默认输入的 Tensor 采样率已经是 self.target_sample_rate
            # 如果需要在这里做重采样,需要额外传入输入采样率参数
            wav = wav.numpy()

            # 可选:音量归一化逻辑(如果 Tensor 没归一化)
            if self.config.get("volume_normalize", True):
                max_val = np.abs(wav).max()
                if max_val > 0:
                    wav = wav / max_val * 0.9
        else:
            raise ValueError(f"Unsupported audio type: {type(audio_input)}")

        return wav

    def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
        """获取参考音频片段"""
        ref_segment_length = (
            int(self.target_sample_rate * self.config["ref_segment_duration"])
            // self.config["latent_hop_length"]
            * self.config["latent_hop_length"]
        )
        wav_length = len(wav)

        if ref_segment_length > wav_length:
            wav = np.tile(wav, ref_segment_length // wav_length + 1)

        return wav[:ref_segment_length]

    def process_audio(self, audio_input: Union[str, torch.Tensor], ref_audio_input: Union[str, torch.Tensor] = None) -> Tuple[np.ndarray, torch.Tensor]:
        """
        处理音频和参考音频。
        """
        wav = self._to_ndarray(audio_input)

        if ref_audio_input is None:
            wav_ref_np = self.get_ref_clip(wav)
        else:
            ref_wav = self._to_ndarray(ref_audio_input)
            wav_ref_np = self.get_ref_clip(ref_wav)

        wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
        return wav, wav_ref

    def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
        """提取 wav2vec2 特征"""
        # processor 期望是 list of numpy
        inputs = self.processor(
            [w.cpu().numpy() for w in wavs], 
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        ).input_values

        with torch.no_grad():
            with torch.amp.autocast(self.device_type, dtype=self.dtype):
                feat = self.feature_extractor(inputs.to(self.feature_extractor.device))

        feats_mix = (
            feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
        ) / 3

        return feats_mix

    @torch.no_grad()
    def tokenize(self, audios: List[Union[str, torch.Tensor]]):
        """
        支持音频路径列表或 Tensor 列表。
        """
        batch_wavs = []
        batch_ref_wavs = []

        for audio_item in audios:
            wav, wav_ref = self.process_audio(audio_input=audio_item, ref_audio_input=audio_item)
            batch_wavs.append(torch.from_numpy(wav).float())
            batch_ref_wavs.append(wav_ref.squeeze(0))

        # Padding wavs
        wav_lengths = [len(w) for w in batch_wavs]
        max_wav_len = max(wav_lengths)
        padded_wavs = torch.zeros(len(batch_wavs), max_wav_len, dtype=self.dtype).to(self.device)
        for i, w in enumerate(batch_wavs):
            padded_wavs[i, :len(w)] = w.to(self.dtype)

        # Padding ref_wavs
        ref_lengths = [len(w) for w in batch_ref_wavs]
        max_ref_len = max(ref_lengths)
        padded_ref_wavs = torch.zeros(len(batch_ref_wavs), max_ref_len, dtype=self.dtype).to(self.device)
        for i, w in enumerate(batch_ref_wavs):
            padded_ref_wavs[i, :len(w)] = w.to(self.dtype)

        # 提取特征
        feats = self.extract_wav2vec2_features(padded_wavs)

        batch = {
            "wav": padded_wavs,
            "ref_wav": padded_ref_wavs,
            "feat": feats,
        }

        semantic_tokens, global_tokens = self.model.tokenize(batch)

        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        return {"semantic_tokens": semantic_tokens, "global_tokens": global_tokens}

    async def batch_tokenize_async(self, audios: list) -> list[dict[str, torch.Tensor]]:
        tokenized = self.tokenize(audios)
        responses = []
        for i in range(len(audios)):
            responses.append({
                "global_tokens": tokenized["global_tokens"][i],
                "semantic_tokens": tokenized["semantic_tokens"][i]
            })
        return responses

    async def tokenize_async(self, audio: Union[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        output = await self._batch_processor.add_request(
            single_input=audio
        )
        return output

# ------------------------------------------------------------------
# Smoke-test script (requires real model files and audio on disk)
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Configure your model path here.
    MODEL_DIR = "/data/yumu/model/ark_tts_v1"
    
    # Initialization.
    # NOTE: without the real environment this line fails (files not found);
    # run it on a machine where the model directory exists.
    tokenizer = SparkTokenizer(model_path=MODEL_DIR, device="cuda" if torch.cuda.is_available() else "cpu")
    
    # Prepare data: a local wav path, plus a waveform Tensor built from it.
    dummy_wav_path = "/data/yumu/arktts/dufu.wav" 
    # Load the file into a Tensor (assumes the model expects 16 kHz input).
    import torchaudio
    dummy_tensor, sr = torchaudio.load(dummy_wav_path)

    # 1. Path input.
    print("Testing path input...")
    if os.path.exists(dummy_wav_path):
        res1 = tokenizer.tokenize([dummy_wav_path])
        print(f"Path results: {res1['semantic_tokens'].shape}")

    # 2. Tensor input.
    print("Testing tensor input...")
    res2 = tokenizer.tokenize([dummy_tensor])
    print(f"Tensor results: {res2['semantic_tokens'].shape}")

    # 3. Mixed input (a list containing str and Tensor is supported).
    print("Testing mixed input...")
    # For demonstration we pass the same tensor twice.
    res3 = tokenizer.tokenize([dummy_tensor, dummy_tensor])
    print(f"Mixed results: {res3['semantic_tokens'].shape}")
    
    print("All tests passed!")