import os
import random
import sys

import torch
import torchaudio
from datasets import load_dataset, Dataset
from tqdm import tqdm

# Make the local CosyVoice checkout importable
sys.path.append('/root/autodl-tmp/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
# ------------------------
# Configuration
# ------------------------
COMMON_VOICE_LANGUAGE = "en"
DATASET_NAME = "tau/commonsense_qa"
OUTPUT_DATASET_PATH = './commonsense_qa_with_audio'  # output directory
SAMPLE_RATE = 16000
# ------------------------
# Helper functions
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly pick one utterance (audio plus transcript) from the VoxPopuli dataset
    (used here in place of the original Common Voice) to serve as the zero-shot prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset[idx]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    sr = audio['sampling_rate']
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    # Return a (1, num_samples) tensor plus the raw transcript used as prompt text
    return waveform.unsqueeze(0), sample['raw_text']
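

# A quick sanity check for the prompt sampler (illustrative only; run it once
# `common_voice` has been loaded further below):
#   prompt_wav, prompt_txt = get_random_prompt(common_voice)
#   print(prompt_wav.shape, prompt_txt)  # expected: torch.Size([1, N]) and the raw transcript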
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """
    Convert the input text to speech with the CosyVoice2 model, using a randomly
    sampled prompt for zero-shot inference.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
        all_speech = []
        # inference_zero_shot yields chunks, each carrying a 'tts_speech' tensor
        for chunk in cosyvoice.inference_zero_shot(
                query_text,
                prompt_text,
                prompt_speech,
                stream=stream,
                text_frontend=False
        ):
            all_speech.append(chunk['tts_speech'])
        # Concatenate all generated speech chunks along the time axis
        combined_speech = torch.cat(all_speech, dim=-1)
        sample_rate_val = cosyvoice.sample_rate
        return {
            'audio_tensor': combined_speech,
            'sample_rate': sample_rate_val
        }
    except Exception as e:
        print(f"Error converting text to audio: {e}")
        return None
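

# Example use of text_to_audio (a minimal sketch; assumes `cosyvoice` and
# `common_voice` have been initialized as in the setup section below):
#   out = text_to_audio("Where would you find a seat at a stadium?", cosyvoice, common_voice)
#   if out is not None:
#       # out['audio_tensor'] is (1, num_samples); out['sample_rate'] is cosyvoice.sample_rate
#       torchaudio.save("example.wav", out['audio_tensor'], out['sample_rate'])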
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS for a single Commonsense QA example.
    In this script only the example['question'] field is synthesized.
    """
    query = example['question']
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is not None:
        return {
            'audio_tensor': audio_result['audio_tensor'],
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None
# ------------------------
# Data loading and model initialization
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset...")
common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train')
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")

print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # replace with the actual model path
    load_jit=True,   # use the JIT-exported modules for faster inference
    load_trt=False,  # TensorRT acceleration disabled
    fp16=False       # keep full-precision weights
)

print("Loading Commonsense QA dataset...")
dataset = load_dataset(DATASET_NAME)
# To process only the train split, use: dataset = load_dataset(DATASET_NAME, split="train")

# Create the output directory
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
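
# Expected on-disk layout once the main loop below has run (illustrative):
#   commonsense_qa_with_audio/
#       <split>/
#           0.wav, 1.wav, ...      # one synthesized question per example index
#           progress.txt           # resume checkpoint: next index to process
#           final_dataset/         # Hugging Face Dataset written by save_to_disk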
# ------------------------
# Main processing loop
# ------------------------
final_dataset_dict = {}  # holds the processed dataset for each split

for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Progress file used to resume an interrupted run
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"Failed to read progress file: {e}")
    final_samples = []
    # Process every example in the split
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        # Samples below start_index were handled in a previous run: skip synthesis
        # and only record the metadata of wav files that already exist
        if i < start_index:
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            if os.path.exists(wav_path):
                # Keep all original fields plus the audio path
                sample_dict = {k: sample[k] for k in sample.keys()}
                sample_dict["audio_filepath"] = wav_path
                final_samples.append(sample_dict)
            continue

        sample = split_dataset[i]
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
        if result is not None:
            audio_tensor = result['audio_tensor']
            # torchaudio.save expects a (channels, num_samples) tensor
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']
            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                print(f"Failed to save wav for sample {i}: {e}")
                continue
            # Keep all original fields plus the generated audio path
            sample_dict = {k: sample[k] for k in sample.keys()}
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
        else:
            print(f"Sample {i} processing failed, no audio generated.")

        # Update the progress record
        with open(progress_file, "w") as f:
            f.write(str(i + 1))
    # Build a Hugging Face Dataset from the collected samples and write it to disk
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)
    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

print("All splits processed; the final datasets have been saved.")