import os
import re
from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
from transformers import PreTrainedTokenizer, AutoTokenizer, WhisperFeatureExtractor
import torch
import torch.nn.functional as F
import logging
from typing import List, Dict, Any, Literal, Optional, Union
from datasets import load_dataset
from torch.utils.data import DataLoader


def has_punctuation(text: str) -> bool:
    # Matches common Chinese and English punctuation marks.
    pattern = r"[,。!?;:()“”‘’、,.!?;:()\[\]{}\"']"
    return bool(re.search(pattern, text))
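
# Quick sanity check for has_punctuation (illustrative only, not executed here):
#   has_punctuation("hello world")  -> False
#   has_punctuation("你好,世界")     -> True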

ALL_TASKS = ["stt", "tts-a", "vc"]


class ark_infer_processor:
    """Builds inference-time prompts (input_ids / attention_mask) for the stt,
    tts-a and vc tasks. The *_token_offset values shift audio-codebook ids into
    the text tokenizer's vocabulary range."""

    def __init__(
        self,
        glm_tokenizer: SpeechTokenExtractor,
        bicodec_tokenizer: SparkTokenizer,
        text_tokenizer: PreTrainedTokenizer,
        max_length: int = 512,
        glm_semantic_token_offset: int = 151727,
        semantic_token_offset: int = 172207,
        global_token_offset: int = 168111,
        audio_path_name: str = "audio",
        device: str = "cpu",
    ):
        self.glm_tokenizer = glm_tokenizer
        self.bicodec_tokenizer = bicodec_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length
        self.glm_semantic_token_offset = glm_semantic_token_offset
        self.semantic_token_offset = semantic_token_offset
        self.global_token_offset = global_token_offset
        self.device = device
        self.audio_path_name = audio_path_name

    def _process_example_stt(self, audio_path: str):
        ## target audio
        with torch.no_grad():
            glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
            glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
            semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])["semantic_tokens"]
        glm_semantic_tokens_list = (
            (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
        )
        semantic_tokens_list = (
            (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
        )
        input_ids = (
            self.text_tokenizer.encode("<|start_glm_token|>")
            + glm_semantic_tokens_list
            + self.text_tokenizer.encode("<|end_glm_token|>")
            + self.text_tokenizer.encode("<|start_semantic_token|>")
            + semantic_tokens_list
            + self.text_tokenizer.encode("<|end_semantic_token|>")
            + self.text_tokenizer.encode("<|start_content|>")
        )
        attention_mask = [1] * len(input_ids)
        return input_ids, attention_mask
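
    # Resulting stt prompt layout (a sketch of the ids built above):
    #   <|start_glm_token|> glm tokens <|end_glm_token|>
    #   <|start_semantic_token|> semantic tokens <|end_semantic_token|> <|start_content|>
    # Per the stt training layout in ark_processor below, generation is expected
    # to continue with the transcript text.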

    def _process_example_tts_a(self, text: str, ref_audio_path: str):
        with torch.no_grad():
            global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])["global_tokens"]
        all_text = "<|start_content|>" + text + "<|end_content|>"
        global_tokens_list = (
            (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
        )
        text_tokens = self.text_tokenizer(
            all_text, truncation=True, max_length=self.max_length
        )
        input_ids = (
            self.text_tokenizer.encode("<|start_global_token|>")
            + global_tokens_list
            + self.text_tokenizer.encode("<|end_global_token|>")
            + text_tokens["input_ids"]
        )
        attention_mask = [1] * len(input_ids)
        return input_ids, attention_mask
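
    # Resulting tts-a prompt layout (sketch):
    #   <|start_global_token|> speaker global tokens <|end_global_token|>
    #   <|start_content|> text <|end_content|>
    # Generation is then expected to continue with semantic tokens, terminated
    # by <|im_end|>, matching the tts-a training layout in ark_processor.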

    def _process_example_vc(self, audio_path: str, ref_audio_path: str):
        with torch.no_grad():
            semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])["semantic_tokens"]
            new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])["global_tokens"]
        semantic_tokens_list = (
            (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
        )
        new_global_tokens_list = (
            (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
        )
        all_str = (
            "<|start_global_token|>"
            + self.text_tokenizer.decode(new_global_tokens_list)
            + "<|end_global_token|>"
            + "<|start_semantic_token|>"
            + self.text_tokenizer.decode(semantic_tokens_list)
            + "<|end_semantic_token|>"
            + "<|end_content|>"
        )
        inputs = self.text_tokenizer(all_str)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask
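
    # Resulting vc prompt layout (sketch):
    #   <|start_global_token|> ref-speaker global tokens <|end_global_token|>
    #   <|start_semantic_token|> source semantic tokens <|end_semantic_token|> <|end_content|>
    # Generation is then expected to continue with the converted semantic
    # tokens, matching the vc training layout in ark_processor.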

    def process_input(
        self,
        task: Literal["stt", "tts-a", "vc"],
        audio_path: str | None = None,
        ref_audio_path: str | None = None,
        text: str | None = None,
    ):
        """Load the given audio, extract features, and return the tokenized prompt for the task."""
        if task == "stt":
            assert audio_path is not None
            input_ids, attention_mask = self._process_example_stt(audio_path)
        elif task == "tts-a":
            assert ref_audio_path is not None and text is not None
            input_ids, attention_mask = self._process_example_tts_a(
                text, ref_audio_path
            )
        elif task == "vc":
            assert audio_path is not None and ref_audio_path is not None
            input_ids, attention_mask = self._process_example_vc(
                audio_path, ref_audio_path
            )
        else:
            raise ValueError(
                f"Unsupported task: {task}, all supported tasks: {ALL_TASKS}"
            )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
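
# A minimal usage sketch (paths are placeholders; the three tokenizer objects
# are built the same way as in the __main__ block at the bottom of this file):
#   proc = ark_infer_processor(glm_tokenizer, bicodec_tokenizer, text_tokenizer, device="cuda:0")
#   out = proc.process_input(task="tts-a", ref_audio_path="ref.wav", text="Hello there")
#   out["input_ids"]       flat list of token ids, ready to feed to the model
#   out["attention_mask"]  all ones, same length as input_ids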


class ark_processor:
    """Builds training examples (input_ids / attention_mask / labels) for the
    stt, tts-a and vc tasks, using the same token offsets as ark_infer_processor."""

    def __init__(
        self,
        glm_tokenizer: SpeechTokenExtractor,
        bicodec_tokenizer: SparkTokenizer,
        text_tokenizer: PreTrainedTokenizer,
        max_length: int = 512,
        glm_semantic_token_offset: int = 151727,
        semantic_token_offset: int = 172207,
        global_token_offset: int = 168111,
        audio_path_name: str = "audio",
        device: str = "cpu",
    ):
        self.glm_tokenizer = glm_tokenizer
        self.bicodec_tokenizer = bicodec_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length
        self.glm_semantic_token_offset = glm_semantic_token_offset
        self.semantic_token_offset = semantic_token_offset
        self.global_token_offset = global_token_offset
        self.device = device
        self.audio_path_name = audio_path_name

    def process_example(self, example: Dict[str, Any]):
        """
        Executed in parallel by multiple CPU workers: loads, resamples, and
        feature-extracts / tokenizes a single example.
        """
        task = example.get("task", "stt")
        audio_path = example.get(self.audio_path_name, "")
        ref_audio_path = example.get("ref_audio", "")
        vc_audio = example.get("vc_audio", "")
        text = example.get("text", "")
| if task == "stt": | |
| ##target 音频 | |
| with torch.no_grad(): | |
| glm_semantic_tokens = self.glm_tokenizer.extract([audio_path]) | |
| glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long) | |
| semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens'] | |
| glm_semantic_tokens_list = (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0] | |
| semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0] | |
| # print(f"len of semantic is {len(semantic_tokens_list)}") | |
| ##对text进行token | |
| text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length) | |
| input_ids = self.text_tokenizer.encode("<|start_glm_token|>") + glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \ | |
| + self.text_tokenizer.encode("<|start_semantic_token|>") + semantic_tokens_list + self.text_tokenizer.encode( | |
| "<|end_semantic_token|>") \ | |
| + self.text_tokenizer.encode("<|start_content|>") + text_tokens["input_ids"] + self.text_tokenizer.encode("<|end_content|>") \ | |
| + self.text_tokenizer.encode("<|im_end|>") | |
| attention_mask = [1] * (len(input_ids)) | |
| labels = [-100] * (len(semantic_tokens_list) + 5 + len(glm_semantic_tokens_list)) + text_tokens["input_ids"] + self.text_tokenizer.encode( | |
| "<|end_content|>") + self.text_tokenizer.encode("<|im_end|>") | |
| elif task == "tts-a": | |
| with torch.no_grad(): | |
| semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens'] | |
| global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens'] | |
| all_text = "<|start_content|>" + text + "<|end_content|>" | |
| global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0] | |
| text_tokens = self.text_tokenizer(all_text, truncation=True, max_length=self.max_length) | |
| semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0] | |
| input_ids = self.text_tokenizer.encode("<|start_global_token|>") + global_tokens_list + self.text_tokenizer.encode( | |
| "<|end_global_token|>") + text_tokens["input_ids"] + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>") | |
| attention_mask = [1] * len(input_ids) | |
| labels = [-100] * (len(text_tokens["input_ids"]) + 2 + len(global_tokens_list)) + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>") | |
| elif task == "vc": | |
| with torch.no_grad(): | |
| semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens'] | |
| global_tokens = self.bicodec_tokenizer.tokenize([audio_path])['global_tokens'] | |
| # global_tokens, semantic_tokens=self.bicodec_tokenizer.tokenize(audio_path=audio_path) | |
| # new_global_tokens, new_semantic_tokens=self.bicodec_tokenizer.tokenize(vc_audio,ref_audio_path) | |
| new_semantic_tokens = self.bicodec_tokenizer.tokenize([vc_audio])['semantic_tokens'] | |
| new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens'] | |
| global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0] | |
| semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0] | |
| new_global_tokens_list = (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0] | |
| new_semantic_tokens_list = (new_semantic_tokens + self.semantic_token_offset).cpu().tolist()[0] | |
| all_str = "<|start_global_token|>" + self.text_tokenizer.decode(new_global_tokens_list) + "<|end_global_token|>" + "<|start_semantic_token|>" + self.text_tokenizer.decode( | |
| semantic_tokens_list) + "<|end_semantic_token|>" + "<|end_content|>" + self.text_tokenizer.decode(new_semantic_tokens_list) + "<|im_end|>" | |
| ##add token and mask | |
| inputs = self.text_tokenizer(all_str) | |
| input_ids = inputs['input_ids'] | |
| attention_mask = inputs['attention_mask'] | |
| labels = [-100] * (5 + len(new_global_tokens_list) + len(semantic_tokens_list)) + new_semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>") | |
        else:
            ## fall back to the stt layout by default
            with torch.no_grad():
                glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
                glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
                semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])["semantic_tokens"]
            glm_semantic_tokens_list = (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
            semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
            ## tokenize the text
            text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
            input_ids = (
                self.text_tokenizer.encode("<|start_glm_token|>") + glm_semantic_tokens_list
                + self.text_tokenizer.encode("<|end_glm_token|>")
                + self.text_tokenizer.encode("<|start_semantic_token|>") + semantic_tokens_list
                + self.text_tokenizer.encode("<|end_semantic_token|>")
                + text_tokens["input_ids"]
                + self.text_tokenizer.encode("<|im_end|>")
            )
            # 4 = the four audio special tokens preceding the text (no <|start_content|> here).
            attention_mask = (
                [1] * (len(semantic_tokens_list) + 4 + len(glm_semantic_tokens_list))
                + text_tokens["attention_mask"]
                + [1]
            )
            labels = (
                [-100] * (len(semantic_tokens_list) + 4 + len(glm_semantic_tokens_list))
                + text_tokens["input_ids"]
                + self.text_tokenizer.encode("<|im_end|>")
            )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
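
    # Expected example fields per task (a sketch; field names follow the
    # example.get() calls above, with "audio" as the default audio_path_name):
    #   stt   -> {"task": "stt",   "audio": ..., "text": ...}
    #   tts-a -> {"task": "tts-a", "audio": ..., "ref_audio": ..., "text": ...}
    #   vc    -> {"task": "vc",    "audio": ..., "ref_audio": ..., "vc_audio": ...}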


def create_tts_collate_fn(
    pad_token_id: int,
    processor,  # ark_processor
    max_length: Optional[int] = None,  # optional truncation cap, e.g. 512
    truncation_side: str = "right",  # "right" or "left"; truncates from the right by default
):
    """
    Factory for a collate_fn with manual padding and optional truncation.

    Args:
        pad_token_id: pad value for input_ids
        processor: your ark_processor; must provide .process_example()
        max_length: if given, truncate each example to this length before batching
        truncation_side: "right" | "left", which side to truncate from
    """
    label_pad_value = -100
    attention_mask_pad_value = 0

    def _truncate_1d(x: torch.Tensor, keep_len: int, side: str) -> torch.Tensor:
        if x.numel() <= keep_len:
            return x
        if side == "right":
            return x[:keep_len]
        elif side == "left":
            return x[-keep_len:]
        else:
            raise ValueError(f"Unsupported truncation_side: {side}")

    def _to_long_tensor(x) -> torch.Tensor:
        if isinstance(x, torch.Tensor):
            return x.detach().clone().long()
        return torch.tensor(x, dtype=torch.long)

    def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # 1) Preprocess (and drop empty examples).
        proc = [processor.process_example(ex) for ex in examples if ex]
        proc = [d for d in proc if d and ("input_ids" in d) and ("attention_mask" in d) and ("labels" in d)]
        if len(proc) == 0:
            # Return an empty batch so the DataLoader does not crash.
            return {
                "input_ids": torch.empty(0, dtype=torch.long),
                "attention_mask": torch.empty(0, dtype=torch.long),
                "labels": torch.empty(0, dtype=torch.long),
            }
        # 2) Per-example truncation (only if max_length is set).
        if max_length is not None:
            trimmed = []
            for ex in proc:
                ids = _to_long_tensor(ex["input_ids"])
                mask = _to_long_tensor(ex["attention_mask"])
                labs = _to_long_tensor(ex["labels"])
                keep_len = min(max_length, ids.numel())
                ids = _truncate_1d(ids, keep_len, truncation_side)
                mask = _truncate_1d(mask, keep_len, truncation_side)
                labs = _truncate_1d(labs, keep_len, truncation_side)
                trimmed.append({"input_ids": ids, "attention_mask": mask, "labels": labs})
            proc = trimmed
        # 3) Find the longest example in this batch (after truncation).
        max_len_in_batch = max(int(len(ex["input_ids"])) for ex in proc)
        # 4) Right-pad every example to the batch max length.
        padded_input_ids_list = []
        padded_attention_mask_list = []
        padded_labels_list = []
        for ex in proc:
            ids = _to_long_tensor(ex["input_ids"])
            mask = _to_long_tensor(ex["attention_mask"])
            labs = _to_long_tensor(ex["labels"])
            need = max_len_in_batch - ids.numel()
            if need < 0:
                # Edge case: with max_length=None an over-long example can exceed the batch max.
                keep_len = max_len_in_batch
                ids = _truncate_1d(ids, keep_len, "right")
                mask = _truncate_1d(mask, keep_len, "right")
                labs = _truncate_1d(labs, keep_len, "right")
                need = 0
            pad_dims = (0, need)
            ids = F.pad(ids, pad_dims, mode="constant", value=pad_token_id)
            mask = F.pad(mask, pad_dims, mode="constant", value=attention_mask_pad_value)
            labs = F.pad(labs, pad_dims, mode="constant", value=label_pad_value)
            padded_input_ids_list.append(ids)
            padded_attention_mask_list.append(mask)
            padded_labels_list.append(labs)
        # 5) Stack into a batch.
        batch = {
            "input_ids": torch.stack(padded_input_ids_list, dim=0),
            "attention_mask": torch.stack(padded_attention_mask_list, dim=0),
            "labels": torch.stack(padded_labels_list, dim=0),
        }
        return batch

    return collate_fn
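
# A minimal sketch of the factory in isolation, using a hypothetical stub
# processor so that no audio models are needed (assumes the examples are
# already tokenized):
#   class _EchoProcessor:
#       def process_example(self, ex):
#           return ex
#   collate = create_tts_collate_fn(pad_token_id=0, processor=_EchoProcessor(), max_length=8)
#   batch = collate([
#       {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
#       {"input_ids": [4, 5],    "attention_mask": [1, 1],    "labels": [-100, 5]},
#   ])
#   batch["input_ids"].shape == (2, 3); the shorter row is right-padded with the
#   pad id (0), its attention_mask with 0, and its labels with -100.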


if __name__ == "__main__":
    device = "cuda:0"
    bicodec_audio_tokenizer_path = "/data/arki_production/model/SparkAudio/Spark-TTS-0___5B/"
    glm_speech_tokenizer_path = "/data/yumu/model/glm-4-voice-tokenizer"
    feature_extractor = WhisperFeatureExtractor.from_pretrained(glm_speech_tokenizer_path)
    audio_model = WhisperVQEncoder.from_pretrained(glm_speech_tokenizer_path).eval().to(device)
    glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=device)
    text_tokenizer = AutoTokenizer.from_pretrained("/data/yumu/model/ark_audio_v1_0_3_b", trust_remote_code=True)
    bicodec_tokenizer = SparkTokenizer(model_path=bicodec_audio_tokenizer_path, device=device)

    # Configuration
    DATASET_PATH = "/data/yumu/glm_asr_vllm/test/data/test_meeting.jsonl"
    MAX_LENGTH = 4096
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {DEVICE}")

    # --- Load the streaming dataset ---
    print(f"Loading dataset '{DATASET_PATH}' in streaming mode...")
    streaming_dataset = load_dataset("json", data_files=DATASET_PATH, streaming=True)["train"]

    # --- Build the data-processing pipeline ---
    print("Shuffling the stream with buffer_size=10000...")
    shuffled_dataset = streaming_dataset.shuffle(buffer_size=10000, seed=42)
    processor = ark_processor(
        glm_tokenizer=glm_tokenizer,
        bicodec_tokenizer=bicodec_tokenizer,
        text_tokenizer=text_tokenizer,
        device=DEVICE,
        audio_path_name="audio",
    )
    collate_fn = create_tts_collate_fn(text_tokenizer.pad_token_id, processor, max_length=MAX_LENGTH)

    # Build the final DataLoader.
    data_loader = DataLoader(
        shuffled_dataset,
        batch_size=10,  # adjust for your GPU memory and model size
        collate_fn=collate_fn,
        num_workers=0,  # DataLoader workers that pull examples from the shuffled stream
    )

    print("\n--- Streaming DataLoader demo ---")
    print("Fetching and inspecting the first batch:\n")
    first_batch = next(iter(data_loader))
    print("Got the first batch; padding was applied in collate_fn.")
    for key, value in first_batch.items():
        if value is not None:
            print(f" - {key}: shape={value.shape}, dtype={value.dtype}")