GPA_DEMO / models /bicodec_tokenizer /spark_tokenizer.py
wanglamao
init
528efee
# -*- coding: utf-8 -*-
# Time :2025/3/29 10:30
# Author :Hui Huang
import os
from typing import Literal, Optional, Tuple, Dict, Any, List, Union
import torch
import torchaudio
import torchaudio.transforms as TT
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
import numpy as np
from loguru import logger
from pathlib import Path
# ----------------- 假设这些模块位于你的项目路径下 -----------------
from .utils.file import load_config
from .utils.audio import load_audio
from .models.bicodec import BiCodec
from .base_model import SparkBaseModel
from .batch_processor import AsyncBatchEngine
# ---------------------------------------------------------------
__all__ = ["SparkTokenizer"]
class SparkTokenizer:
def __init__(
self,
model_path: str,
device: Literal["cpu", "cuda", "mps"] | str = "cuda",
attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = "eager",
batch_size: int = 32,
wait_timeout: float = 0.01,
):
self.device = torch.device(device)
self.model_dir = Path(model_path)
# 1. 加载配置
self.config = load_config(self.model_dir / "config.yaml")
self.device_type = "cuda" if "cuda" in str(device) else "cpu"
self.dtype = torch.float16 if self.device_type == "cuda" else torch.float32
self.target_sample_rate = self.config.get("sample_rate", 16000)
# 2. 加载模型
wav2vec_path = self.model_dir / "wav2vec2-large-xlsr-53"
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
self.feature_extractor = Wav2Vec2Model.from_pretrained(
wav2vec_path,
attn_implementation=attn_implementation,
torch_dtype=self.dtype
)
self.feature_extractor.config.output_hidden_states = True
self.feature_extractor.to(self.device)
self.feature_extractor.eval()
# BiCodec model
self.model = (
BiCodec.load_from_checkpoint(str(self.model_dir)).to(self.device).half()
)
self.model.eval()
# 异步处理引擎
self._batch_processor = AsyncBatchEngine(
processing_function=self.batch_tokenize_async,
batch_size=batch_size,
wait_timeout=wait_timeout
)
def _to_ndarray(self, audio_input: Union[str, Path, torch.Tensor]) -> np.ndarray:
"""
将输入(路径或Tensor)统一转换为指定采样率的 numpy 数组。
"""
if isinstance(audio_input, (str, Path)):
# 如果是路径,直接使用原有的 load_audio
wav = load_audio(
str(audio_input),
sampling_rate=self.target_sample_rate,
volume_normalize=self.config.get("volume_normalize", True),
)
elif isinstance(audio_input, torch.Tensor):
# 如果是 Tensor
wav = audio_input.detach().cpu().float()
# 处理通道: [C, T] -> [T]
if wav.ndim > 1:
wav = torch.mean(wav, dim=0)
# 这里默认输入的 Tensor 采样率已经是 self.target_sample_rate
# 如果需要在这里做重采样,需要额外传入输入采样率参数
wav = wav.numpy()
# 可选:音量归一化逻辑(如果 Tensor 没归一化)
if self.config.get("volume_normalize", True):
max_val = np.abs(wav).max()
if max_val > 0:
wav = wav / max_val * 0.9
else:
raise ValueError(f"Unsupported audio type: {type(audio_input)}")
return wav
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
"""获取参考音频片段"""
ref_segment_length = (
int(self.target_sample_rate * self.config["ref_segment_duration"])
// self.config["latent_hop_length"]
* self.config["latent_hop_length"]
)
wav_length = len(wav)
if ref_segment_length > wav_length:
wav = np.tile(wav, ref_segment_length // wav_length + 1)
return wav[:ref_segment_length]
def process_audio(self, audio_input: Union[str, torch.Tensor], ref_audio_input: Union[str, torch.Tensor] = None) -> Tuple[np.ndarray, torch.Tensor]:
"""
处理音频和参考音频。
"""
wav = self._to_ndarray(audio_input)
if ref_audio_input is None:
wav_ref_np = self.get_ref_clip(wav)
else:
ref_wav = self._to_ndarray(ref_audio_input)
wav_ref_np = self.get_ref_clip(ref_wav)
wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
return wav, wav_ref
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
"""提取 wav2vec2 特征"""
# processor 期望是 list of numpy
inputs = self.processor(
[w.cpu().numpy() for w in wavs],
sampling_rate=16000,
return_tensors="pt",
padding=True,
).input_values
with torch.no_grad():
with torch.amp.autocast(self.device_type, dtype=self.dtype):
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
feats_mix = (
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
) / 3
return feats_mix
@torch.no_grad()
def tokenize(self, audios: List[Union[str, torch.Tensor]]):
"""
支持音频路径列表或 Tensor 列表。
"""
batch_wavs = []
batch_ref_wavs = []
for audio_item in audios:
wav, wav_ref = self.process_audio(audio_input=audio_item, ref_audio_input=audio_item)
batch_wavs.append(torch.from_numpy(wav).float())
batch_ref_wavs.append(wav_ref.squeeze(0))
# Padding wavs
wav_lengths = [len(w) for w in batch_wavs]
max_wav_len = max(wav_lengths)
padded_wavs = torch.zeros(len(batch_wavs), max_wav_len, dtype=self.dtype).to(self.device)
for i, w in enumerate(batch_wavs):
padded_wavs[i, :len(w)] = w.to(self.dtype)
# Padding ref_wavs
ref_lengths = [len(w) for w in batch_ref_wavs]
max_ref_len = max(ref_lengths)
padded_ref_wavs = torch.zeros(len(batch_ref_wavs), max_ref_len, dtype=self.dtype).to(self.device)
for i, w in enumerate(batch_ref_wavs):
padded_ref_wavs[i, :len(w)] = w.to(self.dtype)
# 提取特征
feats = self.extract_wav2vec2_features(padded_wavs)
batch = {
"wav": padded_wavs,
"ref_wav": padded_ref_wavs,
"feat": feats,
}
semantic_tokens, global_tokens = self.model.tokenize(batch)
if self.device.type == "cuda":
torch.cuda.empty_cache()
return {"semantic_tokens": semantic_tokens, "global_tokens": global_tokens}
async def batch_tokenize_async(self, audios: list) -> list[dict[str, torch.Tensor]]:
tokenized = self.tokenize(audios)
responses = []
for i in range(len(audios)):
responses.append({
"global_tokens": tokenized["global_tokens"][i],
"semantic_tokens": tokenized["semantic_tokens"][i]
})
return responses
async def tokenize_async(self, audio: Union[str, torch.Tensor]) -> dict[str, torch.Tensor]:
output = await self._batch_processor.add_request(
single_input=audio
)
return output
# ------------------------------------------------------------------
# 测试用例
# ------------------------------------------------------------------
if __name__ == "__main__":
# 配置你的模型路径
MODEL_DIR = "/data/yumu/model/ark_tts_v1"
# 初始化
# 注意:在没有真实环境时,这行会因为找不到文件报错,请在有环境的地方运行
tokenizer = SparkTokenizer(model_path=MODEL_DIR, device="cuda" if torch.cuda.is_available() else "cpu")
# 准备数据:一个是本地存在的 wav 路径,一个是构造的 Tensor
dummy_wav_path = "/data/yumu/arktts/dufu.wav"
# 构造一个 16kHz 的 2 秒音频 Tensor (假设模型要求16k)
import torchaudio
dummy_tensor, sr = torchaudio.load(dummy_wav_path)
# 1. 测试路径输入
print("Testing path input...")
if os.path.exists(dummy_wav_path):
res1 = tokenizer.tokenize([dummy_wav_path])
print(f"Path results: {res1['semantic_tokens'].shape}")
# 2. 测试 Tensor 输入
print("Testing tensor input...")
res2 = tokenizer.tokenize([dummy_tensor])
print(f"Tensor results: {res2['semantic_tokens'].shape}")
# 3. 测试混合输入 (List 包含 str 和 Tensor)
print("Testing mixed input...")
# 为了演示,我们传两个相同的 tensor
res3 = tokenizer.tokenize([dummy_tensor, dummy_tensor])
print(f"Mixed results: {res3['semantic_tokens'].shape}")
print("All tests passed!")