# r1-a/dataset/gsm8k.py
# (added via upload-large-folder tool; commit 19891ba)
import os
import random
import torch
import torchaudio
from datasets import load_dataset, Dataset
import sys
from tqdm import tqdm
sys.path.append('/root/autodl-tmp/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
# Configuration parameters
COMMON_VOICE_LANGUAGE = "en"  # language tag used in log messages only
DATASET_NAME = "gsm8k" # use the gsm8k dataset
OUTPUT_DATASET_PATH = './gsm8k_with_audio'  # root dir for wavs + final datasets
SAMPLE_RATE = 16000  # target rate prompts are resampled to (Hz)
# --- Helper functions ---
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """Draw a random (waveform, transcript) pair to use as a TTS prompt.

    Args:
        common_voice_dataset: an indexable dataset whose rows contain an
            'audio' dict ({'array', 'sampling_rate'}) and a 'raw_text'
            transcript field (voxpopuli schema).
        sample_rate: target sampling rate; the prompt audio is resampled
            when its source rate differs.

    Returns:
        (waveform, text): waveform is a (1, num_samples) float32 tensor at
        `sample_rate`; text is the sample's raw transcript.
    """
    idx = random.randrange(len(common_voice_dataset))
    # Index the row directly; `.select([idx])[0]` would build a whole new
    # Dataset object just to read one sample.
    sample = common_voice_dataset[idx]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    src_rate = audio['sampling_rate']
    if src_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=src_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
    # Add a channel dimension: (num_samples,) -> (1, num_samples).
    return waveform.unsqueeze(0), sample['raw_text']
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """Synthesize speech for `query_text` via CosyVoice2 zero-shot TTS.

    A random prompt (speech + transcript) is drawn from
    `common_voice_dataset` to condition the synthesized voice.

    Args:
        query_text: text to synthesize.
        cosyvoice: CosyVoice2-like object exposing `inference_zero_shot`
            and a `sample_rate` attribute.
        common_voice_dataset: prompt dataset (see `get_random_prompt`).
        stream: forwarded to `inference_zero_shot`.

    Returns:
        {'audio_tensor': tensor, 'sample_rate': int} on success, or None on
        any failure (best-effort: the error is printed, not raised, so the
        caller can skip the sample).
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
        # Optional: save prompt.wav for debugging
        # torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE)
        # The model yields speech in chunks; collect them all. (The former
        # enumerate() index was unused, so iterate directly.)
        chunks = [
            out['tts_speech']
            for out in cosyvoice.inference_zero_shot(
                query_text,
                prompt_text,
                prompt_speech,
                stream=stream,
                text_frontend=False,
            )
        ]
        # Concatenate the chunks along the time axis into one waveform.
        combined_speech = torch.cat(chunks, dim=-1)
        return {'audio_tensor': combined_speech, 'sample_rate': cosyvoice.sample_rate}
    except Exception as e:
        # Deliberate best-effort: report and let the caller skip this sample
        # rather than aborting the whole run.
        print(f"Error converting text to audio: {e}")
        return None
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """Run TTS on a single gsm8k example.

    Assumes the gsm8k row stores the question text under 'question'
    (the 'answer' field is left for the caller to handle).

    Returns:
        {'audio_tensor': ..., 'sample_rate': ...} on success, else None.
    """
    synthesized = text_to_audio(example['question'], cosyvoice, common_voice_dataset, stream=False)
    if synthesized is None:
        return None
    # Forward only the two fields the caller consumes.
    return {
        'audio_tensor': synthesized['audio_tensor'],
        'sample_rate': synthesized['sample_rate'],
    }
# --- Data loading and model initialization ---
print("Loading Common Voice dataset...")
# NOTE(review): despite the variable name and the log messages, this loads
# facebook/voxpopuli, not Common Voice -- confirm which corpus is intended.
# get_random_prompt() reads the voxpopuli-style 'raw_text' field.
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
print(f"Total Common Voice {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # replace with the actual model path
    load_jit=True,
    load_trt=False,
    fp16=False
)
print("Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", 'main')
# Make sure the output root directory exists
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
# --- Main processing loop ---
# Process each split separately: save one .wav per sample and record the
# row for the final dataset.
final_dataset_dict = {}  # final Dataset object per split
for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)
    # Progress file used as a checkpoint for resuming across runs
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"读取进度文件失败:{e}")
    final_samples = []  # accumulates the final dataset rows for this split
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        if i < start_index:
            # Sample handled in a previous run: re-register its wav path
            # (assumed to have been generated back then).
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            # Only keep the row if the wav file actually exists on disk
            if os.path.exists(wav_path):
                final_samples.append({
                    "question_text": sample["question"],
                    "answer": sample["answer"],
                    "audio_filepath": wav_path
                })
            continue
        sample = split_dataset[i]
        # Run the TTS conversion
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
        if result is not None:
            # Ensure the audio tensor shape is (channels, samples)
            audio_tensor = result['audio_tensor']
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']
            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                # NOTE(review): `continue` skips the progress write below, so
                # after a save failure the checkpoint lags behind `i` -- this
                # sample will be retried on the next run.
                print(f"Failed to save wav for sample {i}: {e}")
                continue
            # Record the converted sample in the final dataset
            final_samples.append({
                "question_text": sample["question"],
                "answer": sample["answer"],
                "audio_filepath": output_wav_path
            })
        else:
            print(f"Sample {i} processing failed, no audio generated.")
        # Update the resume checkpoint (next run starts at i + 1)
        with open(progress_file, "w") as f:
            f.write(str(i + 1))
    # Persist this split's results as a Hugging Face Dataset on disk
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)
    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj
print("所有分割处理完毕,最终数据集已保存。")