# GPA_DEMO/data_utils/audio_dataset_ark_audio.py
# author: wanglamao — initial commit (528efee)
import os
import re
from models.bicodec_tokenizer.spark_tokenizer import SparkTokenizer
from models.glm_speech_tokenizer.speech_token_extractor import SpeechTokenExtractor
from models.glm_speech_tokenizer.modeling_whisper import WhisperVQEncoder
from transformers import PreTrainedTokenizer,AutoTokenizer,WhisperFeatureExtractor
import torch
import torch.nn.functional as F
import logging
from typing import List, Dict, Any, Literal, Optional, Union
from datasets import load_dataset
from torch.utils.data import DataLoader
def has_punctuation(text: str) -> bool:
    """Return True if *text* contains any Chinese or English punctuation mark."""
    # Character class covers both full-width (CJK) and ASCII punctuation.
    return re.search(r"[,。!?;:()“”‘’、,.!?;:()\[\]{}\"']", text) is not None
# Supported task identifiers: speech-to-text, TTS with acoustic prompt, voice conversion.
ALL_TASKS = ["stt", "tts-a", "vc"]
class ark_infer_processor:
    """Builds inference prompts (input_ids / attention_mask) for the audio LM.

    Audio files are turned into discrete token streams by two tokenizers (the
    GLM speech tokenizer and the BiCodec/Spark tokenizer).  Each stream is
    shifted by its own vocabulary offset so the ids land in the text
    tokenizer's extended vocabulary, then framed with special marker tokens
    according to the requested task.
    """

    def __init__(
        self,
        glm_tokenizer: SpeechTokenExtractor,
        bicodec_tokenizer: SparkTokenizer,
        text_tokenizer: PreTrainedTokenizer,
        max_length: int = 512,
        glm_semantic_token_offset: int = 151727,
        semantic_token_offset: int = 172207,
        global_token_offset: int = 168111,
        audio_path_name: str = "audio",
        device: str = "cpu",
    ):
        self.glm_tokenizer = glm_tokenizer
        self.bicodec_tokenizer = bicodec_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length
        # Offsets mapping raw audio-token ids into the text vocabulary.
        self.glm_semantic_token_offset = glm_semantic_token_offset
        self.semantic_token_offset = semantic_token_offset
        self.global_token_offset = global_token_offset
        self.device = device
        self.audio_path_name = audio_path_name

    def _shifted_ids(self, tokens, offset):
        """Add *offset* to a (1, T) token tensor and return row 0 as a list."""
        return (tokens + offset).cpu().tolist()[0]

    def _process_example_stt(self, audio_path: str):
        """STT prompt: glm tokens + semantic tokens + open content marker."""
        with torch.no_grad():
            glm_raw = self.glm_tokenizer.extract([audio_path])
            glm_tokens = torch.as_tensor(glm_raw, device="cpu", dtype=torch.long)
            semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
        enc = self.text_tokenizer.encode
        pieces = [
            enc("<|start_glm_token|>"),
            self._shifted_ids(glm_tokens, self.glm_semantic_token_offset),
            enc("<|end_glm_token|>"),
            enc("<|start_semantic_token|>"),
            self._shifted_ids(semantic_tokens, self.semantic_token_offset),
            enc("<|end_semantic_token|>"),
            enc("<|start_content|>"),
        ]
        input_ids = [tok for piece in pieces for tok in piece]
        return input_ids, [1] * len(input_ids)

    def _process_example_tts_a(self, text: str, ref_audio_path: str):
        """TTS prompt: reference-speaker global tokens followed by wrapped text."""
        with torch.no_grad():
            global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
        wrapped = "<|start_content|>" + text + "<|end_content|>"
        # Global tokens come back as (1, 1, T); take the inner row.
        global_ids = (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
        text_tokens = self.text_tokenizer(
            wrapped, truncation=True, max_length=self.max_length
        )
        enc = self.text_tokenizer.encode
        input_ids = (
            enc("<|start_global_token|>")
            + global_ids
            + enc("<|end_global_token|>")
            + text_tokens["input_ids"]
        )
        return input_ids, [1] * len(input_ids)

    def _process_example_vc(self, audio_path: str, ref_audio_path: str):
        """VC prompt: reference-speaker global tokens + source semantic tokens."""
        with torch.no_grad():
            semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
            new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
        semantic_ids = self._shifted_ids(semantic_tokens, self.semantic_token_offset)
        global_ids = (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
        decode = self.text_tokenizer.decode
        # The prompt is rebuilt as text and re-tokenized in one pass.
        all_str = (
            "<|start_global_token|>"
            + decode(global_ids)
            + "<|end_global_token|>"
            + "<|start_semantic_token|>"
            + decode(semantic_ids)
            + "<|end_semantic_token|>"
            + "<|end_content|>"
        )
        encoded = self.text_tokenizer(all_str)
        return encoded["input_ids"], encoded["attention_mask"]

    def process_input(
        self,
        task: Literal["stt", "tts-a", "vc"],
        audio_path: str | None = None,
        ref_audio_path: str | None = None,
        text: str | None = None,
    ):
        """Load the audio for *task*, extract features and return the
        tokenized prompt as ``{"input_ids": ..., "attention_mask": ...}``."""
        if task == "stt":
            assert audio_path is not None
            ids, mask = self._process_example_stt(audio_path)
        elif task == "tts-a":
            assert ref_audio_path is not None and text is not None
            ids, mask = self._process_example_tts_a(text, ref_audio_path)
        elif task == "vc":
            assert audio_path is not None and ref_audio_path is not None
            ids, mask = self._process_example_vc(audio_path, ref_audio_path)
        else:
            raise ValueError(
                f"Unsupported task: {task}, all supported tasks: {ALL_TASKS}"
            )
        return {
            "input_ids": ids,
            "attention_mask": mask,
        }
class ark_processor:
    """Training-time preprocessor producing input_ids / attention_mask / labels.

    For each example it extracts discrete audio tokens (GLM speech tokenizer
    and BiCodec/Spark tokenizer), shifts them into the text tokenizer's
    extended vocabulary via the configured offsets, frames them with special
    marker tokens, and builds the label sequence with the prompt positions
    masked to -100.
    """

    def __init__(self,
                 glm_tokenizer: SpeechTokenExtractor,
                 bicodec_tokenizer: SparkTokenizer,
                 text_tokenizer: PreTrainedTokenizer,
                 max_length: int = 512,
                 glm_semantic_token_offset: int = 151727,
                 semantic_token_offset: int = 172207,
                 global_token_offset: int = 168111,
                 audio_path_name: str = "audio",
                 device: str = 'cpu'):
        self.glm_tokenizer = glm_tokenizer
        self.bicodec_tokenizer = bicodec_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length
        # Offsets mapping raw audio-token ids into the text vocabulary.
        self.glm_semantic_token_offset = glm_semantic_token_offset
        self.semantic_token_offset = semantic_token_offset
        self.global_token_offset = global_token_offset
        self.device = device
        # Key under which the example dict stores the target audio path.
        self.audio_path_name = audio_path_name

    def process_example(self, example: Dict[str, Any]):
        """Load, featurize and tokenize a single sample.

        This function is executed in parallel by multiple CPU worker
        processes.  Supported tasks are "stt", "tts-a" and "vc"; any other
        task value falls back to an STT-style prompt without the
        content-marker tokens.
        """
        task = example.get("task", "stt")
        audio_path = example.get(self.audio_path_name, "")
        ref_audio_path = example.get("ref_audio", "")
        vc_audio = example.get("vc_audio", "")
        text = example.get("text", "")
        if task == "stt":
            # Target audio -> glm + bicodec semantic token streams.
            with torch.no_grad():
                glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
                glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
                semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
            glm_semantic_tokens_list = (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
            semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
            # Tokenize the transcript text.
            text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
            input_ids = self.text_tokenizer.encode("<|start_glm_token|>") + glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \
                + self.text_tokenizer.encode("<|start_semantic_token|>") + semantic_tokens_list + self.text_tokenizer.encode(
                "<|end_semantic_token|>") \
                + self.text_tokenizer.encode("<|start_content|>") + text_tokens["input_ids"] + self.text_tokenizer.encode("<|end_content|>") \
                + self.text_tokenizer.encode("<|im_end|>")
            attention_mask = [1] * (len(input_ids))
            # "+ 5" masks the five special markers preceding the text
            # (assumes each marker string encodes to exactly one id — TODO confirm).
            labels = [-100] * (len(semantic_tokens_list) + 5 + len(glm_semantic_tokens_list)) + text_tokens["input_ids"] + self.text_tokenizer.encode(
                "<|end_content|>") + self.text_tokenizer.encode("<|im_end|>")
        elif task == "tts-a":
            with torch.no_grad():
                # Content tokens come from the target audio, speaker (global)
                # tokens from the reference audio.
                semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
                global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
            all_text = "<|start_content|>" + text + "<|end_content|>"
            # Global tokens come back as (1, 1, T); take the inner row.
            global_tokens_list = (global_tokens + self.global_token_offset).cpu().tolist()[0][0]
            text_tokens = self.text_tokenizer(all_text, truncation=True, max_length=self.max_length)
            semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
            input_ids = self.text_tokenizer.encode("<|start_global_token|>") + global_tokens_list + self.text_tokenizer.encode(
                "<|end_global_token|>") + text_tokens["input_ids"] + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
            attention_mask = [1] * len(input_ids)
            # "+ 2" masks the two global-token markers; only the semantic
            # tokens (and <|im_end|>) are supervised.
            labels = [-100] * (len(text_tokens["input_ids"]) + 2 + len(global_tokens_list)) + semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
        elif task == "vc":
            with torch.no_grad():
                # Source audio provides the content (semantic) tokens, the
                # reference audio the target-speaker (global) tokens, and
                # vc_audio the converted-target semantic tokens.
                # NOTE: the previous version also tokenized audio_path a
                # second time for its global tokens, but never used them —
                # that redundant (and expensive) call has been removed.
                semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
                new_semantic_tokens = self.bicodec_tokenizer.tokenize([vc_audio])['semantic_tokens']
                new_global_tokens = self.bicodec_tokenizer.tokenize([ref_audio_path])['global_tokens']
            semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
            new_global_tokens_list = (new_global_tokens + self.global_token_offset).cpu().tolist()[0][0]
            new_semantic_tokens_list = (new_semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
            # The prompt is rebuilt as text and re-tokenized in one pass.
            all_str = "<|start_global_token|>" + self.text_tokenizer.decode(new_global_tokens_list) + "<|end_global_token|>" + "<|start_semantic_token|>" + self.text_tokenizer.decode(
                semantic_tokens_list) + "<|end_semantic_token|>" + "<|end_content|>" + self.text_tokenizer.decode(new_semantic_tokens_list) + "<|im_end|>"
            inputs = self.text_tokenizer(all_str)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            # NOTE(review): the label prefix length assumes decode->encode
            # round-trips each token id to exactly one id — confirm for this
            # tokenizer, otherwise labels drift out of alignment.
            labels = [-100] * (5 + len(new_global_tokens_list) + len(semantic_tokens_list)) + new_semantic_tokens_list + self.text_tokenizer.encode("<|im_end|>")
        else:
            # Fallback: STT-style prompt without the content-marker tokens
            # (hence "+ 4" special markers instead of "+ 5").
            with torch.no_grad():
                glm_semantic_tokens = self.glm_tokenizer.extract([audio_path])
                glm_semantic_tokens = torch.as_tensor(glm_semantic_tokens, device="cpu", dtype=torch.long)
                semantic_tokens = self.bicodec_tokenizer.tokenize([audio_path])['semantic_tokens']
            glm_semantic_tokens_list = (glm_semantic_tokens + self.glm_semantic_token_offset).cpu().tolist()[0]
            semantic_tokens_list = (semantic_tokens + self.semantic_token_offset).cpu().tolist()[0]
            # Tokenize the transcript text.
            text_tokens = self.text_tokenizer(text, truncation=True, max_length=self.max_length)
            input_ids = self.text_tokenizer.encode("<|start_glm_token|>") + glm_semantic_tokens_list + self.text_tokenizer.encode("<|end_glm_token|>") \
                + self.text_tokenizer.encode("<|start_semantic_token|>") + semantic_tokens_list + self.text_tokenizer.encode("<|end_semantic_token|>") \
                + text_tokens["input_ids"] \
                + self.text_tokenizer.encode("<|im_end|>")
            attention_mask = [1] * (len(semantic_tokens_list) + 4 + len(glm_semantic_tokens_list)) + text_tokens["attention_mask"] + [1]
            labels = [-100] * (len(semantic_tokens_list) + 4 + len(glm_semantic_tokens_list)) + text_tokens["input_ids"] + self.text_tokenizer.encode("<|im_end|>")
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
def create_tts_collate_fn(
    pad_token_id: int,
    processor,  # ark_processor-like object exposing .process_example()
    max_length: Optional[int] = None,   # per-sample truncation cap, e.g. 512
    truncation_side: str = "right"      # "right" or "left"
):
    """Build a collate_fn that preprocesses, truncates and right-pads a batch.

    Args:
        pad_token_id: fill value for padded positions of ``input_ids``.
        processor: object providing ``process_example(example) -> dict``.
        max_length: if given, each sample is truncated to this length before
            batch padding.
        truncation_side: which side to drop tokens from when truncating.

    Returns:
        A callable mapping a list of raw examples to a dict of stacked
        ``input_ids`` / ``attention_mask`` / ``labels`` long tensors.
    """
    LABEL_PAD = -100
    MASK_PAD = 0

    def _as_long(x) -> torch.Tensor:
        # Accept lists or tensors; always hand back a detached long tensor.
        if isinstance(x, torch.Tensor):
            return x.detach().clone().long()
        return torch.tensor(x, dtype=torch.long)

    def _trim(x: torch.Tensor, keep_len: int, side: str) -> torch.Tensor:
        # Drop tokens from the requested side down to keep_len.
        if x.numel() <= keep_len:
            return x
        if side == "right":
            return x[:keep_len]
        if side == "left":
            return x[-keep_len:]
        raise ValueError(f"Unsupported truncation_side: {side}")

    def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # 1) Preprocess, dropping falsy examples and incomplete results.
        processed = [processor.process_example(ex) for ex in examples if ex]
        processed = [
            d for d in processed
            if d and all(k in d for k in ("input_ids", "attention_mask", "labels"))
        ]
        if not processed:
            # Empty batch: return empty tensors so the DataLoader survives.
            return {
                "input_ids": torch.empty(0, dtype=torch.long),
                "attention_mask": torch.empty(0, dtype=torch.long),
                "labels": torch.empty(0, dtype=torch.long),
            }

        # 2) Convert to tensors, applying per-sample truncation if requested.
        rows = []
        for d in processed:
            ids = _as_long(d["input_ids"])
            mask = _as_long(d["attention_mask"])
            labs = _as_long(d["labels"])
            if max_length is not None:
                keep = min(max_length, ids.numel())
                ids = _trim(ids, keep, truncation_side)
                mask = _trim(mask, keep, truncation_side)
                labs = _trim(labs, keep, truncation_side)
            rows.append((ids, mask, labs))

        # 3) Pad everything on the right up to the batch maximum length.
        target_len = max(ids.numel() for ids, _, _ in rows)

        def _pad_to(t: torch.Tensor, value: int) -> torch.Tensor:
            return F.pad(t, (0, target_len - t.numel()), mode="constant", value=value)

        # 4) Stack into the final batch.
        return {
            "input_ids": torch.stack([_pad_to(i, pad_token_id) for i, _, _ in rows], dim=0),
            "attention_mask": torch.stack([_pad_to(m, MASK_PAD) for _, m, _ in rows], dim=0),
            "labels": torch.stack([_pad_to(l, LABEL_PAD) for _, _, l in rows], dim=0),
        }

    return collate_fn
if __name__ == "__main__":
    # Demo entry point: build the tokenizers, stream a JSONL dataset and show
    # the first collated batch.
    # NOTE(review): models below are loaded onto the hard-coded "cuda:0"
    # while the processor later receives DEVICE ("cuda"/"cpu") — confirm both
    # resolve to the same device on the target machine.
    device = "cuda:0"
    bicodec_audio_tokenizer_path = "/data/arki_production/model/SparkAudio/Spark-TTS-0___5B/"
    glm_speech_tokenizer_path = "/data/yumu/model/glm-4-voice-tokenizer"
    feature_extractor = WhisperFeatureExtractor.from_pretrained(glm_speech_tokenizer_path)
    audio_model = WhisperVQEncoder.from_pretrained(glm_speech_tokenizer_path).eval().to(device)
    glm_tokenizer = SpeechTokenExtractor(model=audio_model, feature_extractor=feature_extractor, device=device)
    text_tokenizer = AutoTokenizer.from_pretrained("/data/yumu/model/ark_audio_v1_0_3_b", trust_remote_code=True)
    bicodec_tokenizer = SparkTokenizer(model_path=bicodec_audio_tokenizer_path, device=device)
    # Configuration.
    DATASET_PATH = "/data/yumu/glm_asr_vllm/test/data/test_meeting.jsonl"
    MAX_LENGTH = 4096
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"将使用设备: {DEVICE}")
    # Load the dataset in streaming mode.
    print(f"以流式方式加载数据集 '{DATASET_PATH}'...")
    streaming_dataset = load_dataset("json", data_files=DATASET_PATH, streaming=True)['train']
    # Build the processing pipeline: shuffle the stream with a fixed seed.
    print("正在对数据流进行shuffle,buffer_size=1000...")
    shuffled_dataset = streaming_dataset.shuffle(buffer_size=10000, seed=42)
    processor = ark_processor(
        glm_tokenizer=glm_tokenizer,
        bicodec_tokenizer=bicodec_tokenizer,
        text_tokenizer=text_tokenizer,
        device=DEVICE,
        audio_path_name="audio")
    collate_fn = create_tts_collate_fn(text_tokenizer.pad_token_id, processor, max_length=4096)
    # Final DataLoader over the shuffled stream.
    data_loader = DataLoader(
        shuffled_dataset,
        batch_size=10,   # tune to GPU memory / model size
        collate_fn=collate_fn,
        num_workers=0    # DataLoader workers pulling from the shuffled stream
    )
    print("\n--- 高性能流式 DataLoader 演示 ---")
    print("将从DataLoader中获取并展示第一个批次的数据:\n")
    first_batch = next(iter(data_loader))
    print("成功获取第一个批次!数据已在collate_fn中填充。")
    for key, value in first_batch.items():
        if value is not None:
            # Bug fix: previously printed the whole tensor as "dtype";
            # report the actual dtype alongside the shape.
            print(f" - {key}: shape={value.shape}, dtype={value.dtype}")