| import os |
| import random |
| import torch |
| import torchaudio |
| from datasets import load_dataset, Dataset |
| import sys |
| from tqdm import tqdm |
|
|
| sys.path.append('/root/autodl-tmp/CosyVoice') |
| from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 |
| from cosyvoice.utils.file_utils import load_wav |
|
|
| |
# Language code for the prompt-speech corpus loaded below (VoxPopuli config name).
COMMON_VOICE_LANGUAGE = "en"
# NOTE(review): defined but never read below — the GSM8K load uses the literal
# "openai/gsm8k" directly; kept here for configuration visibility.
DATASET_NAME = "gsm8k"
# Root directory receiving per-split wav files, progress files, and final datasets.
OUTPUT_DATASET_PATH = './gsm8k_with_audio'
# Target sample rate (Hz) for prompt audio fed to CosyVoice2.
SAMPLE_RATE = 16000
|
|
| |
|
|
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """Pick a random (speech, transcript) pair to use as a zero-shot prompt.

    Args:
        common_voice_dataset: HF ``Dataset`` whose rows carry an ``audio``
            column (dict with ``array`` and ``sampling_rate``) and a
            ``raw_text`` transcript column (VoxPopuli schema).
        sample_rate: target sample rate in Hz; audio stored at any other
            rate is resampled to it.

    Returns:
        Tuple of (waveform tensor shaped ``(1, num_samples)``, transcript str).
    """
    # Direct row access instead of .select([idx])[0]: avoids materializing a
    # throwaway one-row Dataset on every call.
    idx = random.randrange(len(common_voice_dataset))
    sample = common_voice_dataset[idx]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    src_rate = audio['sampling_rate']
    if src_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=src_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
    # Add a channel dimension: downstream TTS code expects (1, num_samples).
    return waveform.unsqueeze(0), sample['raw_text']
|
|
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """Synthesize speech for *query_text* with CosyVoice2 zero-shot inference.

    A random (speech, transcript) pair drawn from *common_voice_dataset*
    serves as the voice prompt.

    Args:
        query_text: text to synthesize.
        cosyvoice: initialized CosyVoice2 model instance.
        common_voice_dataset: prompt corpus passed through to get_random_prompt.
        stream: forwarded to ``inference_zero_shot``.

    Returns:
        dict with ``audio_tensor`` (speech chunks concatenated along the last
        axis) and ``sample_rate`` (taken from ``cosyvoice.sample_rate``), or
        ``None`` on any failure — this is deliberately best-effort so the
        batch driver can skip bad samples.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

        # Zero-shot inference yields one or more speech chunks; collect them.
        # (Original used enumerate() with an unused index.)
        chunks = []
        for result in cosyvoice.inference_zero_shot(
            query_text,
            prompt_text,
            prompt_speech,
            stream=stream,
            text_frontend=False
        ):
            chunks.append(result['tts_speech'])

        # Concatenate chunks along the time (last) axis into one utterance.
        combined_speech = torch.cat(chunks, dim=-1)
        return {'audio_tensor': combined_speech, 'sample_rate': cosyvoice.sample_rate}
    except Exception as e:
        # Intentional broad catch: log and signal failure to the caller.
        print(f"Error converting text to audio: {e}")
        return None
|
|
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """Run TTS for a single gsm8k example.

    Expects the gsm8k schema: the question text lives under 'question' and
    the answer under 'answer'.  Note: *sample_rate* is accepted for call-site
    symmetry but is not read here — synthesis uses the module constant.

    Returns:
        dict with 'audio_tensor' and 'sample_rate', or None when synthesis
        failed.
    """
    synthesized = text_to_audio(example['question'], cosyvoice, common_voice_dataset, stream=False)
    if synthesized is None:
        return None
    return {
        'audio_tensor': synthesized['audio_tensor'],
        'sample_rate': synthesized['sample_rate'],
    }
|
|
| |
|
|
# --- One-time setup: prompt corpus, TTS model, and the target dataset. ---

print("Loading Common Voice dataset...")
# Consistency fix: use the COMMON_VOICE_LANGUAGE constant instead of a
# hard-coded "en" so the config block at the top actually controls the
# language that gets loaded.
common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train')
print(f"Total Common Voice {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")

print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',
    load_jit=True,
    load_trt=False,
    fp16=False
)

print("Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", 'main')
|
|
| |
# Ensure the root output directory exists before any split writes into it.
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

# Maps split_name -> datasets.Dataset of successfully processed samples.
final_dataset_dict = {}

for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Resume support: progress.txt holds the index of the next sample to
    # process; indices below it are reloaded from previously saved wavs.
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            # (Message reads "Failed to read progress file" in Chinese.)
            print(f"读取进度文件失败:{e}")

    final_samples = []
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        if i < start_index:
            # Visited in a previous run: reuse the saved wav if present.
            # If the wav is missing (earlier failure), fall through below
            # and retry synthesis for this index.
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")

            if os.path.exists(wav_path):
                final_samples.append({
                    "question_text": sample["question"],
                    "answer": sample["answer"],
                    "audio_filepath": wav_path
                })
                continue

        sample = split_dataset[i]
        # Best-effort TTS: result is None when synthesis failed.
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            audio_tensor = result['audio_tensor']
            # torchaudio.save expects a (channels, time) tensor.
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']
            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                # Drop the sample entirely if the wav cannot be written;
                # note this `continue` also skips the progress write below.
                print(f"Failed to save wav for sample {i}: {e}")
                continue

            final_samples.append({
                "question_text": sample["question"],
                "answer": sample["answer"],
                "audio_filepath": output_wav_path
            })
        else:
            print(f"Sample {i} processing failed, no audio generated.")

        # Checkpoint after each processed sample so a crash loses at most one.
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

    # Persist this split's finished samples as a HF dataset on disk.
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)
    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

# (Message reads "All splits processed; final dataset saved." in Chinese.)
print("所有分割处理完毕,最终数据集已保存。")
|
|