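"""
Generate spoken versions of CommonsenseQA questions with CosyVoice2 zero-shot TTS.

For every example, a random VoxPopuli utterance (audio plus transcript) is used
as the zero-shot prompt, the question text is synthesized, the waveform is
written to <OUTPUT_DATASET_PATH>/<split>/<index>.wav, and the original fields
plus the audio path are saved per split as a Hugging Face dataset. A per-split
progress file lets an interrupted run resume where it left off.
"""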
import os
import random
import torch
import torchaudio
from datasets import load_dataset, Dataset
import sys
from tqdm import tqdm

sys.path.append('/root/autodl-tmp/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# ------------------------
# Configuration
# ------------------------
COMMON_VOICE_LANGUAGE = "en"
DATASET_NAME = "tau/commonsense_qa"
OUTPUT_DATASET_PATH = './commonsense_qa_with_audio'  # output directory
SAMPLE_RATE = 16000

# ------------------------
# Helper functions
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly pick one utterance and its transcript from the VoxPopuli dataset
    (used here in place of Common Voice) to serve as the zero-shot prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset.select([idx])[0]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    sr = audio['sampling_rate']
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    return waveform.unsqueeze(0), sample['raw_text']

def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """
    Convert the input text to speech with CosyVoice2, using a randomly drawn
    prompt for zero-shot inference.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

        all_speech = []
        # Collect every speech chunk produced by the zero-shot inference generator
        for chunk in cosyvoice.inference_zero_shot(
            query_text,
            prompt_text,
            prompt_speech,
            stream=stream,
            text_frontend=False
        ):
            all_speech.append(chunk['tts_speech'])

        # Concatenate all generated speech chunks
        combined_speech = torch.cat(all_speech, dim=-1)
        sample_rate_val = cosyvoice.sample_rate

        return {
            'audio_tensor': combined_speech,
            'sample_rate': sample_rate_val
        }
    except Exception as e:
        print(f"Error converting text to audio: {e}")
        return None

def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS on a single Commonsense QA example.
    Only the example['question'] field is synthesized here.
    """
    query = example['question']
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is not None:
        return {
            'audio_tensor': audio_result['audio_tensor'],
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None

# ------------------------
# Data loading and model initialization
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset...")
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")

print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # replace with the actual model path
    load_jit=True,
    load_trt=False,
    fp16=False
)

print("Loading Commonsense QA dataset...")
dataset = load_dataset(DATASET_NAME)
# To process only the train split: dataset = load_dataset(DATASET_NAME, split="train")

# Create the output directory
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

# ------------------------
# Main processing loop
# ------------------------
final_dataset_dict = {}  # final processed dataset for each split

for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)
    
    # Progress file so an interrupted run can resume where it left off
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"Failed to read progress file: {e}")
    
    final_samples = []

    # Process every example in this split
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        # Already processed in a previous run: skip synthesis and only record the metadata for the existing wav file
        if i < start_index:
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            if os.path.exists(wav_path):
                # Keep all original fields plus the audio path
                sample_dict = {k: sample[k] for k in sample.keys()}
                sample_dict["audio_filepath"] = wav_path
                final_samples.append(sample_dict)
            continue
        
        sample = split_dataset[i]
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
        
        if result is not None:
            audio_tensor = result['audio_tensor']
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']

            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                print(f"Failed to save wav for sample {i}: {e}")
                continue

            # Keep all original fields plus the generated audio path
            sample_dict = {k: sample[k] for k in sample.keys()}
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
        else:
            print(f"Sample {i} processing failed, no audio generated.")
        
        # Update the progress record
        with open(progress_file, "w") as f:
            f.write(str(i + 1))
    
    # Build a Hugging Face Dataset for this split and save it to disk
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)

    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

print("All splits processed; final datasets saved.")
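
# Sketch (not part of the original run): one way to reload a processed split
# later, assuming the directory layout written above.
#   from datasets import Dataset
#   train = Dataset.load_from_disk(os.path.join(OUTPUT_DATASET_PATH, "train", "final_dataset"))
#   print(train[0]["question"], train[0]["audio_filepath"])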