Spaces: Running on L4
| # -*- coding: utf-8 -*- | |
| # Time :2025/3/29 10:30 | |
| # Author :Hui Huang | |
| import os | |
| from typing import Literal, Optional, Tuple, Dict, Any, List, Union | |
| import torch | |
| import torchaudio | |
| import torchaudio.transforms as TT | |
| from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor | |
| import numpy as np | |
| from loguru import logger | |
| from pathlib import Path | |
| # ----------------- 假设这些模块位于你的项目路径下 ----------------- | |
| from .utils.file import load_config | |
| from .utils.audio import load_audio | |
| from .models.bicodec import BiCodec | |
| from .base_model import SparkBaseModel | |
| from .batch_processor import AsyncBatchEngine | |
| # --------------------------------------------------------------- | |
| __all__ = ["SparkTokenizer"] | |
| class SparkTokenizer: | |
| def __init__( | |
| self, | |
| model_path: str, | |
| device: Literal["cpu", "cuda", "mps"] | str = "cuda", | |
| attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = "eager", | |
| batch_size: int = 32, | |
| wait_timeout: float = 0.01, | |
| ): | |
| self.device = torch.device(device) | |
| self.model_dir = Path(model_path) | |
| # 1. 加载配置 | |
| self.config = load_config(self.model_dir / "config.yaml") | |
| self.device_type = "cuda" if "cuda" in str(device) else "cpu" | |
| self.dtype = torch.float16 if self.device_type == "cuda" else torch.float32 | |
| self.target_sample_rate = self.config.get("sample_rate", 16000) | |
| # 2. 加载模型 | |
| wav2vec_path = self.model_dir / "wav2vec2-large-xlsr-53" | |
| self.processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path) | |
| self.feature_extractor = Wav2Vec2Model.from_pretrained( | |
| wav2vec_path, | |
| attn_implementation=attn_implementation, | |
| torch_dtype=self.dtype | |
| ) | |
| self.feature_extractor.config.output_hidden_states = True | |
| self.feature_extractor.to(self.device) | |
| self.feature_extractor.eval() | |
| # BiCodec model | |
| self.model = ( | |
| BiCodec.load_from_checkpoint(str(self.model_dir)).to(self.device).half() | |
| ) | |
| self.model.eval() | |
| # 异步处理引擎 | |
| self._batch_processor = AsyncBatchEngine( | |
| processing_function=self.batch_tokenize_async, | |
| batch_size=batch_size, | |
| wait_timeout=wait_timeout | |
| ) | |
| def _to_ndarray(self, audio_input: Union[str, Path, torch.Tensor]) -> np.ndarray: | |
| """ | |
| 将输入(路径或Tensor)统一转换为指定采样率的 numpy 数组。 | |
| """ | |
| if isinstance(audio_input, (str, Path)): | |
| # 如果是路径,直接使用原有的 load_audio | |
| wav = load_audio( | |
| str(audio_input), | |
| sampling_rate=self.target_sample_rate, | |
| volume_normalize=self.config.get("volume_normalize", True), | |
| ) | |
| elif isinstance(audio_input, torch.Tensor): | |
| # 如果是 Tensor | |
| wav = audio_input.detach().cpu().float() | |
| # 处理通道: [C, T] -> [T] | |
| if wav.ndim > 1: | |
| wav = torch.mean(wav, dim=0) | |
| # 这里默认输入的 Tensor 采样率已经是 self.target_sample_rate | |
| # 如果需要在这里做重采样,需要额外传入输入采样率参数 | |
| wav = wav.numpy() | |
| # 可选:音量归一化逻辑(如果 Tensor 没归一化) | |
| if self.config.get("volume_normalize", True): | |
| max_val = np.abs(wav).max() | |
| if max_val > 0: | |
| wav = wav / max_val * 0.9 | |
| else: | |
| raise ValueError(f"Unsupported audio type: {type(audio_input)}") | |
| return wav | |
| def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: | |
| """获取参考音频片段""" | |
| ref_segment_length = ( | |
| int(self.target_sample_rate * self.config["ref_segment_duration"]) | |
| // self.config["latent_hop_length"] | |
| * self.config["latent_hop_length"] | |
| ) | |
| wav_length = len(wav) | |
| if ref_segment_length > wav_length: | |
| wav = np.tile(wav, ref_segment_length // wav_length + 1) | |
| return wav[:ref_segment_length] | |
| def process_audio(self, audio_input: Union[str, torch.Tensor], ref_audio_input: Union[str, torch.Tensor] = None) -> Tuple[np.ndarray, torch.Tensor]: | |
| """ | |
| 处理音频和参考音频。 | |
| """ | |
| wav = self._to_ndarray(audio_input) | |
| if ref_audio_input is None: | |
| wav_ref_np = self.get_ref_clip(wav) | |
| else: | |
| ref_wav = self._to_ndarray(ref_audio_input) | |
| wav_ref_np = self.get_ref_clip(ref_wav) | |
| wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float() | |
| return wav, wav_ref | |
| def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: | |
| """提取 wav2vec2 特征""" | |
| # processor 期望是 list of numpy | |
| inputs = self.processor( | |
| [w.cpu().numpy() for w in wavs], | |
| sampling_rate=16000, | |
| return_tensors="pt", | |
| padding=True, | |
| ).input_values | |
| with torch.no_grad(): | |
| with torch.amp.autocast(self.device_type, dtype=self.dtype): | |
| feat = self.feature_extractor(inputs.to(self.feature_extractor.device)) | |
| feats_mix = ( | |
| feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16] | |
| ) / 3 | |
| return feats_mix | |
| def tokenize(self, audios: List[Union[str, torch.Tensor]]): | |
| """ | |
| 支持音频路径列表或 Tensor 列表。 | |
| """ | |
| batch_wavs = [] | |
| batch_ref_wavs = [] | |
| for audio_item in audios: | |
| wav, wav_ref = self.process_audio(audio_input=audio_item, ref_audio_input=audio_item) | |
| batch_wavs.append(torch.from_numpy(wav).float()) | |
| batch_ref_wavs.append(wav_ref.squeeze(0)) | |
| # Padding wavs | |
| wav_lengths = [len(w) for w in batch_wavs] | |
| max_wav_len = max(wav_lengths) | |
| padded_wavs = torch.zeros(len(batch_wavs), max_wav_len, dtype=self.dtype).to(self.device) | |
| for i, w in enumerate(batch_wavs): | |
| padded_wavs[i, :len(w)] = w.to(self.dtype) | |
| # Padding ref_wavs | |
| ref_lengths = [len(w) for w in batch_ref_wavs] | |
| max_ref_len = max(ref_lengths) | |
| padded_ref_wavs = torch.zeros(len(batch_ref_wavs), max_ref_len, dtype=self.dtype).to(self.device) | |
| for i, w in enumerate(batch_ref_wavs): | |
| padded_ref_wavs[i, :len(w)] = w.to(self.dtype) | |
| # 提取特征 | |
| feats = self.extract_wav2vec2_features(padded_wavs) | |
| batch = { | |
| "wav": padded_wavs, | |
| "ref_wav": padded_ref_wavs, | |
| "feat": feats, | |
| } | |
| semantic_tokens, global_tokens = self.model.tokenize(batch) | |
| if self.device.type == "cuda": | |
| torch.cuda.empty_cache() | |
| return {"semantic_tokens": semantic_tokens, "global_tokens": global_tokens} | |
| async def batch_tokenize_async(self, audios: list) -> list[dict[str, torch.Tensor]]: | |
| tokenized = self.tokenize(audios) | |
| responses = [] | |
| for i in range(len(audios)): | |
| responses.append({ | |
| "global_tokens": tokenized["global_tokens"][i], | |
| "semantic_tokens": tokenized["semantic_tokens"][i] | |
| }) | |
| return responses | |
| async def tokenize_async(self, audio: Union[str, torch.Tensor]) -> dict[str, torch.Tensor]: | |
| output = await self._batch_processor.add_request( | |
| single_input=audio | |
| ) | |
| return output | |
# ------------------------------------------------------------------
# Smoke tests
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Model directory; adjust to your environment.
    MODEL_DIR = "/data/yumu/model/ark_tts_v1"
    # Initialization loads config and checkpoints from MODEL_DIR and fails if
    # they are missing, so run this where the model files exist.
    tokenizer = SparkTokenizer(model_path=MODEL_DIR, device="cuda" if torch.cuda.is_available() else "cpu")

    # Every test below needs this local wav file, so check it up front instead
    # of letting torchaudio.load crash on a missing path.
    dummy_wav_path = "/data/yumu/arktts/dufu.wav"
    if not os.path.exists(dummy_wav_path):
        raise SystemExit(f"Test audio not found: {dummy_wav_path}")

    # Load the same file as a tensor. SparkTokenizer assumes tensor inputs are
    # already at its target sample rate, so resample if the file differs.
    dummy_tensor, sr = torchaudio.load(dummy_wav_path)
    if sr != tokenizer.target_sample_rate:
        dummy_tensor = torchaudio.functional.resample(dummy_tensor, sr, tokenizer.target_sample_rate)

    # 1. Path input
    print("Testing path input...")
    res1 = tokenizer.tokenize([dummy_wav_path])
    print(f"Path results: {res1['semantic_tokens'].shape}")

    # 2. Tensor input
    print("Testing tensor input...")
    res2 = tokenizer.tokenize([dummy_tensor])
    print(f"Tensor results: {res2['semantic_tokens'].shape}")

    # 3. Mixed input: one path and one tensor in the same batch.
    print("Testing mixed input...")
    res3 = tokenizer.tokenize([dummy_wav_path, dummy_tensor])
    print(f"Mixed results: {res3['semantic_tokens'].shape}")

    print("All tests passed!")