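"""
Synthesize audio for Commonsense QA questions with CosyVoice2.

Pipeline: load VoxPopuli (en) as a pool of zero-shot voice prompts, load
tau/commonsense_qa, run TTS on each question, write one wav per sample, and
save every split as a Hugging Face dataset whose rows carry an
`audio_filepath` column. Progress is checkpointed per split so an
interrupted run can resume.
"""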
import os
import random
import sys

import torch
import torchaudio
from datasets import load_dataset, Dataset
from tqdm import tqdm

# Make the local CosyVoice checkout importable.
sys.path.append('/root/autodl-tmp/CosyVoice')

from cosyvoice.cli.cosyvoice import CosyVoice2

# Configuration
COMMON_VOICE_LANGUAGE = "en"
DATASET_NAME = "tau/commonsense_qa"  # Hugging Face dataset id
OUTPUT_DATASET_PATH = './commonsense_qa_with_audio'
SAMPLE_RATE = 16000  # CosyVoice2's zero-shot interface expects 16 kHz prompt audio


def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from the VoxPopuli dataset
    (used here in place of the original Common Voice) to serve as the prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset[idx]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    sr = audio['sampling_rate']
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    # Return a (1, T) waveform plus the raw transcript used as the prompt text.
    return waveform.unsqueeze(0), sample['raw_text']

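# Shape contract (illustrative values): the helper returns a float32 tensor of
# shape (1, T) at 16 kHz together with its transcript, so a 3-second prompt
# would give prompt_wav.shape == torch.Size([1, 48000]).
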
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """
    Convert the input text to speech with CosyVoice2, using a randomly drawn
    prompt for zero-shot inference.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

        all_speech = []
        for out in cosyvoice.inference_zero_shot(
            query_text,
            prompt_text,
            prompt_speech,
            stream=stream,
            text_frontend=False
        ):
            all_speech.append(out['tts_speech'])

        # Concatenate the generated segments along the time axis.
        combined_speech = torch.cat(all_speech, dim=-1)
        return {
            'audio_tensor': combined_speech,
            'sample_rate': cosyvoice.sample_rate
        }
    except Exception as e:
        print(f"Error converting text to audio: {e}")
        return None

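# Note: inference_zero_shot is a generator. With stream=True it yields audio in
# incremental chunks; with stream=False it typically yields one chunk per
# synthesized text segment. The loop above collects every chunk either way.
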
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS for a single Commonsense QA example.
    Here only the example's 'question' field is synthesized.
    """
    query = example['question']
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is not None:
        return {
            'audio_tensor': audio_result['audio_tensor'],
            'sample_rate': audio_result['sample_rate']
        }
    return None


print("Loading VoxPopuli (as Common Voice) dataset...") |
|
|
common_voice = load_dataset("facebook/voxpopuli", "en", split='train') |
|
|
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") |
|
|
|
|
|
print("Initializing CosyVoice2 model...") |
|
|
cosyvoice = CosyVoice2( |
|
|
'/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', |
|
|
load_jit=True, |
|
|
load_trt=False, |
|
|
fp16=False |
|
|
) |
|
|
|
|
|
print("Loading Commonsense QA dataset...") |
|
|
dataset = load_dataset("tau/commonsense_qa") |
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

final_dataset_dict = {}

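# Output layout: each split gets its own directory under OUTPUT_DATASET_PATH,
# holding the per-sample wavs ("{i}.wav"), a progress.txt checkpoint, and the
# final Arrow dataset ("final_dataset").
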
for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Resume support: progress.txt stores the index of the next unprocessed sample.
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"Failed to read progress file: {e}")

    final_samples = []

    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        # Samples before start_index were handled in an earlier run: just
        # re-collect their metadata if the wav is already on disk.
        if i < start_index:
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            if os.path.exists(wav_path):
                sample_dict = dict(sample)
                sample_dict["audio_filepath"] = wav_path
                final_samples.append(sample_dict)
            continue

        sample = split_dataset[i]
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            audio_tensor = result['audio_tensor']
            if audio_tensor.dim() == 1:
                # torchaudio.save expects a (channels, time) tensor.
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']

            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                print(f"Failed to save wav for sample {i}: {e}")
                continue

            sample_dict = dict(sample)
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
        else:
            print(f"Sample {i} processing failed, no audio generated.")

        # Checkpoint after every sample so an interrupted run resumes at i + 1.
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)

    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

print("All splits processed; final dataset saved.")
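# To reload a processed split later (sketch; paths match save_to_disk above):
#   from datasets import load_from_disk, Audio
#   ds = load_from_disk(os.path.join(OUTPUT_DATASET_PATH, "train", "final_dataset"))
#   ds = ds.cast_column("audio_filepath", Audio(sampling_rate=cosyvoice.sample_rate))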