Add files using upload-large-folder tool
- .gitattributes +27 -0
- r1-a/dataset/ai2_arc.py +175 -0
- r1-a/dataset/alpaca.py +346 -0
- r1-a/dataset/commonsense.py +175 -0
- r1-a/dataset/examqa.py +440 -0
- r1-a/dataset/examqa_rewrite.py +487 -0
- r1-a/dataset/final_tts.py +316 -0
- r1-a/dataset/gsm8k.py +169 -0
- r1-a/dataset/gsm8k_with_audio/test/299.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/301.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/302.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/314.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/316.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/350.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/358.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/359.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/369.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/372.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/376.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/385.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/394.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/395.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/397.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/400.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/401.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/447.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/45.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/450.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/454.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/457.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/458.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/459.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/463.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/465.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/515.wav +3 -0
- r1-a/dataset/gsm8k_with_audio/test/877.wav +0 -0
- r1-a/dataset/gsm8k_with_audio/test/964.wav +0 -0
- r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json +20 -0
- r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json +13 -0
- r1-a/dataset/gsm8k_with_audio/test/progress.txt +1 -0
- r1-a/dataset/pkusafe.py +171 -0
- r1-a/dataset/pkusafe_tts.py +279 -0
- r1-a/dataset/retry_rewrite.py +442 -0
- r1-a/dataset/retts.py +559 -0
- r1-a/dataset/sciq.py +176 -0
- r1-a/dataset/shp.py +148 -0
- r1-a/dataset/shp_tts.py +494 -0
- r1-a/dataset/ultrachat.py +261 -0
- r1-a/dataset/ultrachat_tts.py +382 -0
- r1-a/prompt_only_examine.py +48 -0
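
The dataset scripts added in this commit all follow the same pattern: synthesize one wav per example, record its path in an "audio_filepath" column, and save each split with Dataset.save_to_disk, which is what produces the final_dataset/dataset_info.json and state.json entries listed above. Below is a minimal sketch (not part of the commit) of reading the committed GSM8K test split back; it assumes gsm8k.py uses the same column name as the scripts shown below, and the stored paths were written at generation time so they may need re-rooting to the local checkout.

from datasets import load_from_disk
import torchaudio

ds = load_from_disk("r1-a/dataset/gsm8k_with_audio/test/final_dataset")
print(ds.column_names)  # original GSM8K fields plus "audio_filepath" (assumed)
wav, sr = torchaudio.load(ds[0]["audio_filepath"])  # path may need re-rooting
print(wav.shape, sr)
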
.gitattributes
CHANGED

@@ -33,3 +33,30 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/450.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/45.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/358.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/369.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/316.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/454.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/376.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/395.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/359.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/447.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/299.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/302.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/394.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/350.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/385.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/401.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/463.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/458.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/457.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/515.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/301.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/372.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/314.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/397.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/465.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/459.wav filter=lfs diff=lfs merge=lfs -text
+r1-a/dataset/gsm8k_with_audio/test/400.wav filter=lfs diff=lfs merge=lfs -text
r1-a/dataset/ai2_arc.py
ADDED

@@ -0,0 +1,175 @@

import os
import random
import torch
import torchaudio
from datasets import load_dataset, Dataset
import sys
from tqdm import tqdm

sys.path.append('/root/autodl-tmp/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# ------------------------
# Configuration
# ------------------------
COMMON_VOICE_LANGUAGE = "en"
DATASET_NAME = "ai2_arc"
OUTPUT_DATASET_PATH = './arc_easy_with_audio'  # output directory
SAMPLE_RATE = 16000

# ------------------------
# Helper functions
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from VoxPopuli
    (used here in place of the original Common Voice) to serve as the prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset.select([idx])[0]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    sr = audio['sampling_rate']
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    return waveform.unsqueeze(0), sample['raw_text']

def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """
    Convert the input text to speech with the CosyVoice2 model,
    using a random prompt for zero-shot inference.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

        all_speech = []
        for i, j in enumerate(cosyvoice.inference_zero_shot(
            query_text,
            prompt_text,
            prompt_speech,
            stream=stream,
            text_frontend=False
        )):
            all_speech.append(j['tts_speech'])

        # Concatenate all generated speech chunks
        combined_speech = torch.cat(all_speech, dim=-1)
        sample_rate_val = cosyvoice.sample_rate

        return {
            'audio_tensor': combined_speech,
            'sample_rate': sample_rate_val
        }
    except Exception as e:
        print(f"Error converting text to audio: {e}")
        return None

def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS on a single example from the AI2 ARC dataset.
    In this script, only the example['question'] field is converted.
    """
    query = example['question']
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is not None:
        return {
            'audio_tensor': audio_result['audio_tensor'],
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None

# ------------------------
# Data loading and model initialization
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset...")
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")

print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # replace with the actual model path
    load_jit=True,
    load_trt=False,
    fp16=False
)

print("Loading ARC-Easy dataset...")
# To process ARC-Challenge instead, change this to "ARC-Challenge"
dataset = load_dataset("allenai/ai2_arc", "ARC-Easy")

# Create the output directory
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

# ------------------------
# Main processing loop
# ------------------------
final_dataset_dict = {}  # holds the final processed data for each split

for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Progress record used to resume after an interruption
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"Failed to read progress file: {e}")

    final_samples = []

    # Process each example in turn
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        # If already processed, skip generation and only record the existing file's metadata in final_samples
        if i < start_index:
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            if os.path.exists(wav_path):
                # Keep all original fields + the audio path
                sample_dict = {k: sample[k] for k in sample.keys()}
                sample_dict["audio_filepath"] = wav_path
                final_samples.append(sample_dict)
            continue

        sample = split_dataset[i]
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            audio_tensor = result['audio_tensor']
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']

            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                print(f"Failed to save wav for sample {i}: {e}")
                continue

            # Keep all original fields + the generated audio path
            sample_dict = {k: sample[k] for k in sample.keys()}
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
        else:
            print(f"Sample {i} processing failed, no audio generated.")

        # Update the progress record
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

    # Build a Hugging Face Dataset and save it to disk
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)

    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

print("All splits processed; the final datasets have been saved.")
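
Once a split has been written by the script above, audio_filepath is just a string column. If a consumer wants decoded waveforms, the path column can be cast to the datasets Audio feature. A minimal sketch, assuming the ARC script has produced ./arc_easy_with_audio/train/final_dataset (the split name is illustrative):

from datasets import load_from_disk, Audio

arc = load_from_disk("./arc_easy_with_audio/train/final_dataset")
# Casting a column of path strings to Audio makes each access return a dict
# with "array" and "sampling_rate" decoded on the fly.
arc = arc.cast_column("audio_filepath", Audio())
example = arc[0]
print(example["question"])
print(example["audio_filepath"]["sampling_rate"], example["audio_filepath"]["array"].shape)
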
r1-a/dataset/alpaca.py
ADDED

@@ -0,0 +1,346 @@

# --- SET CUDA DEVICE ---
# Method 1: Set environment variable BEFORE importing torch/cosyvoice
# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# --- End CUDA Device Setting ---

import random
import torch
import torchaudio
from datasets import load_dataset, Dataset, load_from_disk
import sys
from tqdm import tqdm
import time

# Check if the specified GPU is available after setting the environment variable
if not torch.cuda.is_available():
    print("WARNING: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. Check your PyTorch installation and GPU drivers.")
    print("Attempting to run on CPU, but this will be very slow.")
    # Decide if you want to exit or proceed on CPU
    # sys.exit(1) # Uncomment to exit if GPU not found
    effective_device = torch.device("cpu")
else:
    # Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0
    effective_device = torch.device("cuda:0")
    print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}")  # Should show the name of GPU 1
    print(f"Script will effectively run on: {effective_device}")


sys.path.append('/root/autodl-tmp/CosyVoice')  # Make sure this path is correct
# Import CosyVoice *after* setting the environment variable
try:
    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
    from cosyvoice.utils.file_utils import load_wav
except ImportError as e:
    print(f"Error importing CosyVoice: {e}")
    print("Please ensure the path '/root/autodl-tmp/CosyVoice' is correct and the library is installed.")
    sys.exit(1)

# ------------------------
# Configuration
# ------------------------
COMMON_VOICE_LANGUAGE = "en"
FILTERED_ALPACA_PATH = './alpaca_filtered_for_spoken_dialogue_v2'
SPLITS_TO_PROCESS = ['train']
OUTPUT_DATASET_PATH = './alpaca_filtered_spoken_with_output_audio'  # Keep output path distinct
SAMPLE_RATE = 16000
MAX_TTS_RETRIES = 3
RETRY_DELAY_SECONDS = 2

# ------------------------
# Helper functions (No changes needed here, should run on the visible device)
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from VoxPopuli
    (used in place of the original Common Voice) to serve as the prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset.select([idx])[0]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)  # Created on CPU
    sr = audio['sampling_rate']
    if sr != sample_rate:
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.numel() == 0 or not sample['raw_text']:
        print("Warning: Got an empty prompt, trying again...")
        return get_random_prompt(common_voice_dataset, sample_rate)
    # Return CPU tensor, CosyVoice inference should handle moving it
    return waveform, sample['raw_text']

def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
    """
    Convert the input text to speech with the CosyVoice2 model, using a random prompt
    for zero-shot inference.
    Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
    """
    last_exception = None
    prompt_text = ""  # defined up front so the except block can reference it even if prompt selection fails
    for attempt in range(max_retries):
        try:
            # prompt_speech is initially on CPU
            prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

            all_speech = []
            # cosyvoice.inference_zero_shot should internally use the GPU device it was initialized on
            # (which should be the visible cuda:0, i.e., original cuda:1)
            inference_generator = cosyvoice.inference_zero_shot(
                query_text,
                prompt_text,
                prompt_speech,  # Pass CPU tensor
                stream=stream,
                text_frontend=False
            )
            # Generated chunks 'tts_speech' will be on the GPU
            for i, chunk in enumerate(inference_generator):
                if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
                    all_speech.append(chunk['tts_speech'])
                else:
                    print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")

            if not all_speech:
                # Clear GPU memory cache if an error occurs during generation
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                raise ValueError("TTS inference finished but produced no audio chunks.")

            # combined_speech is on GPU
            combined_speech = torch.cat(all_speech, dim=-1)
            sample_rate_val = cosyvoice.sample_rate

            return {
                # Return GPU tensor, will be moved to CPU before saving
                'audio_tensor': combined_speech,
                'sample_rate': sample_rate_val
            }
        except Exception as e:
            last_exception = e
            print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
            print(f"Text: '{query_text[:100]}...'")
            print(f"Prompt Text: '{prompt_text[:100]}...'")
            # Clear GPU cache on error as well
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            if attempt < max_retries - 1:
                print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
                time.sleep(RETRY_DELAY_SECONDS)
            else:
                print(f"All {max_retries} TTS attempts failed.")

    print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
    print(f"Last error: {last_exception}")
    return None

def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS on a single example from the filtered Alpaca dataset loaded from disk.
    Processes the concatenation of example['instruction'] and example['input'].
    """
    # .get() may return None for missing fields, so fall back to empty strings before concatenating
    text_to_convert = (example.get('instruction') or "") + (example.get('input') or "")
    if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "":
        print(f"Warning: Skipping example due to missing or empty 'instruction'/'input' text: {example.keys()}")
        return None

    audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)

    if audio_result is not None:
        audio_tensor = audio_result['audio_tensor']  # Still on GPU here
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        elif audio_tensor.dim() > 2:
            print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
            audio_tensor = audio_tensor.view(1, -1)

        if audio_tensor.numel() == 0:
            print(f"Warning: Generated audio tensor is empty for text: '{text_to_convert[:60]}...'")
            # Clear GPU cache even for empty tensor? Maybe not needed.
            return None

        return {
            'audio_tensor': audio_tensor,  # Return GPU tensor
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None

# ------------------------
# Data loading and model initialization
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
try:
    # Load prompt dataset to CPU memory
    common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
    print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
    if len(common_voice) == 0:
        raise ValueError("VoxPopuli dataset loaded but contains no samples.")
except Exception as e:
    print(f"Error loading VoxPopuli dataset: {e}")
    sys.exit(1)


print("Initializing CosyVoice2 model...")
try:
    # CosyVoice should automatically initialize on the visible device ('cuda:0', which is original 'cuda:1')
    # No explicit device='cuda:1' needed here due to CUDA_VISIBLE_DEVICES
    cosyvoice = CosyVoice2(
        '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',
        load_jit=True,
        load_trt=False,  # Ensure TRT is False if not set up for GPU 1
        fp16=False  # Check if GPU 1 supports FP16 well if you enable this
        # device=effective_device  # Usually not needed if CUDA_VISIBLE_DEVICES is set, but uncomment if CosyVoice requires it explicitly
    )
    print(f"CosyVoice model initialized. It should be using device: {effective_device}")
except Exception as e:
    print(f"Error initializing CosyVoice2 model: {e}")
    # Try to get more info if it's a CUDA error
    if isinstance(e, RuntimeError) and 'CUDA' in str(e):
        print("This might be a CUDA initialization error. Ensure GPU 1 is functional and has enough memory.")
    sys.exit(1)

print(f"Loading pre-filtered Alpaca dataset(s) from disk: {FILTERED_ALPACA_PATH}")
dataset_dict = {}
loaded_splits_count = 0
for split_name in SPLITS_TO_PROCESS:
    split_dir_name = f"{split_name}_dataset"
    split_path = os.path.join(FILTERED_ALPACA_PATH, split_dir_name)
    print(f"Attempting to load split '{split_name}' from: {split_path}")
    try:
        # Load dataset to CPU memory
        split_dataset = load_from_disk(split_path)

        if not split_dataset:
            print(f"Warning: Dataset loaded from '{split_path}' is empty or invalid. Skipping this split.")
            continue
        dataset_dict[split_name] = split_dataset
        print(f"Successfully loaded split '{split_name}' with {len(split_dataset)} examples.")
        loaded_splits_count += 1
    except FileNotFoundError:
        print(f"Info: Filtered dataset split not found at '{split_path}'. Skipping this split.")
    except Exception as e:
        print(f"Error loading pre-filtered dataset split from '{split_path}': {e}. Skipping this split.")

if loaded_splits_count == 0:
    print(f"Error: Could not load any dataset splits from '{FILTERED_ALPACA_PATH}' using splits '{SPLITS_TO_PROCESS}'.")
    sys.exit(1)

# Create the output directory
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

# ------------------------
# Main processing loop
# ------------------------
final_dataset_dict = {}

for split_name, split_dataset in dataset_dict.items():
    print(f"\nProcessing loaded split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    # ... (progress file reading logic remains the same)
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                content = f.read().strip()
            if content:
                start_index = int(content)
                print(f"Resuming split '{split_name}' TTS from sample index {start_index}")
            else:
                print(f"Progress file '{progress_file}' is empty, starting TTS from index 0.")
                start_index = 0
        except ValueError:
            print(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.")
            start_index = 0
        except Exception as e:
            print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
            start_index = 0


    final_samples = []

    pbar = tqdm(range(start_index, len(split_dataset)), desc=f"TTS on '{split_name}' output field", initial=start_index, total=len(split_dataset))
    for i in pbar:
        sample = split_dataset[i]  # Sample data is on CPU
        output_wav_path = os.path.join(split_output_dir, f"{i}.wav")

        if os.path.exists(output_wav_path):
            sample_dict = {k: sample[k] for k in sample.keys()}
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
            with open(progress_file, "w") as f:
                f.write(str(i + 1))
            continue

        # --- Perform TTS on the target device ---
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            audio_tensor = result['audio_tensor']  # Received tensor is on GPU
            sample_rate_val = result['sample_rate']

            try:
                # --- Move tensor to CPU before saving ---
                audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
                if audio_tensor_save.dim() == 1:
                    audio_tensor_save = audio_tensor_save.unsqueeze(0)
                elif audio_tensor_save.dim() > 2:
                    audio_tensor_save = audio_tensor_save.view(1, -1)

                torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)

                sample_dict = {k: sample[k] for k in sample.keys()}
                sample_dict["audio_filepath"] = output_wav_path
                final_samples.append(sample_dict)

                # --- Explicitly delete GPU tensor and clear cache periodically? ---
                # Can sometimes help prevent memory creep in long loops
                del audio_tensor
                # if i % 50 == 0:  # Example: clear cache every 50 iterations
                #     if torch.cuda.is_available(): torch.cuda.empty_cache()

            except Exception as e:
                print(f"Failed to save wav for sample {i} ('output' field TTS) at {output_wav_path}: {e}")
                # Clear cache on save error too, just in case
                if torch.cuda.is_available(): torch.cuda.empty_cache()
        else:
            print(f"Sample {i} TTS failed after retries (Output Text: '{sample.get('output', 'N/A')[:60]}...'), no audio generated.")
            # No tensor to delete if result is None

        # Update progress file
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

        # --- Optional: Add more frequent cache clearing ---
        # if i % 20 == 0 and torch.cuda.is_available():  # Clear more often if memory is tight
        #     torch.cuda.empty_cache()


    # --- Final cache clear after finishing a split ---
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ... (Saving final dataset logic remains the same)
    if final_samples:
        final_dataset_obj = Dataset.from_list(final_samples)
        final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
        try:
            print(f"Saving final dataset for split '{split_name}' (with new audio paths) to {final_dataset_save_path}...")
            os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True)
            final_dataset_obj.save_to_disk(final_dataset_save_path)
            print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with new audio paths for 'output' field.")
            final_dataset_dict[split_name] = final_dataset_obj
        except Exception as e:
            print(f"Error saving final dataset for split '{split_name}' to disk: {e}")
    else:
        print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.")


print("="*30)
if final_dataset_dict:
    print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.")
    print(f"Processed splits: {list(final_dataset_dict.keys())}")
else:
    print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.")
print("="*30)
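
alpaca.py above does not download Alpaca itself; it expects a pre-filtered dataset to already exist on disk, loading FILTERED_ALPACA_PATH/<split>_dataset for each split in SPLITS_TO_PROCESS and reading the usual instruction/input/output columns. A minimal sketch of that expected layout, using a hypothetical one-row dataset purely to illustrate the directory naming:

from datasets import Dataset

toy = Dataset.from_list([{
    "instruction": "Summarize the sentence in five words.",
    "input": "The cat sat quietly on the warm mat.",
    "output": "Cat sat quietly on mat.",
}])
# alpaca.py loads "<FILTERED_ALPACA_PATH>/<split>_dataset", so the 'train' split lives here:
toy.save_to_disk("./alpaca_filtered_for_spoken_dialogue_v2/train_dataset")
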
r1-a/dataset/commonsense.py
ADDED

@@ -0,0 +1,175 @@

import os
import random
import torch
import torchaudio
from datasets import load_dataset, Dataset
import sys
from tqdm import tqdm

sys.path.append('/root/autodl-tmp/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# ------------------------
# Configuration
# ------------------------
COMMON_VOICE_LANGUAGE = "en"
DATASET_NAME = "commonsense_qa"
OUTPUT_DATASET_PATH = './commonsense_qa_with_audio'  # output directory
SAMPLE_RATE = 16000

# ------------------------
# Helper functions
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from VoxPopuli
    (used here in place of the original Common Voice) to serve as the prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset.select([idx])[0]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    sr = audio['sampling_rate']
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    return waveform.unsqueeze(0), sample['raw_text']

def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """
    Convert the input text to speech with the CosyVoice2 model,
    using a random prompt for zero-shot inference.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

        all_speech = []
        for i, j in enumerate(cosyvoice.inference_zero_shot(
            query_text,
            prompt_text,
            prompt_speech,
            stream=stream,
            text_frontend=False
        )):
            all_speech.append(j['tts_speech'])

        # Concatenate all generated speech chunks
        combined_speech = torch.cat(all_speech, dim=-1)
        sample_rate_val = cosyvoice.sample_rate

        return {
            'audio_tensor': combined_speech,
            'sample_rate': sample_rate_val
        }
    except Exception as e:
        print(f"Error converting text to audio: {e}")
        return None

def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS on a single example from the Commonsense QA dataset.
    In this script, only the example['question'] field is converted.
    """
    query = example['question']
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is not None:
        return {
            'audio_tensor': audio_result['audio_tensor'],
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None

# ------------------------
# Data loading and model initialization
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset...")
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")

print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # replace with the actual model path
    load_jit=True,
    load_trt=False,
    fp16=False
)

print("Loading Commonsense QA dataset...")
dataset = load_dataset("tau/commonsense_qa")
# To process only the train split, use dataset = load_dataset("tau/commonsense_qa", split="train")

# Create the output directory
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

# ------------------------
# Main processing loop
# ------------------------
final_dataset_dict = {}  # holds the final processed data for each split

for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Progress record used to resume after an interruption
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"Failed to read progress file: {e}")

    final_samples = []

    # Process each example in turn
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        # If already processed, skip generation and only record the existing file's metadata in final_samples
        if i < start_index:
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            if os.path.exists(wav_path):
                # Keep all original fields + the audio path
                sample_dict = {k: sample[k] for k in sample.keys()}
                sample_dict["audio_filepath"] = wav_path
                final_samples.append(sample_dict)
            continue

        sample = split_dataset[i]
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            audio_tensor = result['audio_tensor']
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']

            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                print(f"Failed to save wav for sample {i}: {e}")
                continue

            # Keep all original fields + the generated audio path
            sample_dict = {k: sample[k] for k in sample.keys()}
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
        else:
            print(f"Sample {i} processing failed, no audio generated.")

        # Update the progress record
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

    # Build a Hugging Face Dataset and save it to disk
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)

    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

print("All splits processed; the final datasets have been saved.")
r1-a/dataset/examqa.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- SET CUDA DEVICE ---
|
| 2 |
+
# Method 1: Set environment variable BEFORE importing torch/cosyvoice
|
| 3 |
+
# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
|
| 4 |
+
import os
|
| 5 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # <-- Keep your original setting
|
| 6 |
+
# --- End CUDA Device Setting ---
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
import torch
|
| 10 |
+
import torchaudio
|
| 11 |
+
# Import load_from_disk to load the dataset saved by your LLM script
|
| 12 |
+
from datasets import load_dataset, Dataset, load_from_disk, Features, Value, Sequence, ClassLabel # Added Features etc. for robustness
|
| 13 |
+
import sys
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import time
|
| 16 |
+
import logging # Add logging
|
| 17 |
+
import json # For fallback saving
|
| 18 |
+
|
| 19 |
+
# Check if the specified GPU is available after setting the environment variable
|
| 20 |
+
if not torch.cuda.is_available():
|
| 21 |
+
print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES. Cannot run TTS on GPU.")
|
| 22 |
+
print("Check your PyTorch installation, GPU drivers, and CUDA setup.")
|
| 23 |
+
sys.exit(1) # Exit if GPU is required and not found
|
| 24 |
+
else:
|
| 25 |
+
# Since CUDA_VISIBLE_DEVICES is set, the first *visible* device is cuda:0
|
| 26 |
+
effective_device = torch.device("cuda:0")
|
| 27 |
+
print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}")
|
| 28 |
+
print(f"Script will effectively run TTS inference on: {effective_device}")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct
|
| 32 |
+
# Import CosyVoice *after* setting the environment variable
|
| 33 |
+
try:
|
| 34 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 35 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 36 |
+
except ImportError as e:
|
| 37 |
+
print(f"Error importing CosyVoice: {e}")
|
| 38 |
+
print("Please ensure the path '/root/autodl-tmp/CosyVoice' is correct and the library is installed.")
|
| 39 |
+
sys.exit(1)
|
| 40 |
+
|
| 41 |
+
# Setup basic logging
|
| 42 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 43 |
+
|
| 44 |
+
# ------------------------
|
| 45 |
+
# 配置参数
|
| 46 |
+
# ------------------------
|
| 47 |
+
COMMON_VOICE_LANGUAGE = "en"
|
| 48 |
+
# --- Path to the dataset output by the LLM rephrasing script ---
|
| 49 |
+
REPHRASED_DATASET_PATH = './Multi-subject-RLVR_rephrased/train_processed_final' # <-- ADJUST IF YOUR PATH IS DIFFERENT
|
| 50 |
+
# --- Output path for THIS TTS script ---
|
| 51 |
+
TTS_OUTPUT_PATH = './Multi-subject-RLVR_rephrased_with_audio' # <-- New path for results
|
| 52 |
+
SAMPLE_RATE = 16000
|
| 53 |
+
MAX_TTS_RETRIES = 3
|
| 54 |
+
RETRY_DELAY_SECONDS = 2
|
| 55 |
+
# Define the assumed split name for directory structure (even if only one split)
|
| 56 |
+
ASSUMED_INPUT_SPLIT = "train"
|
| 57 |
+
|
| 58 |
+
# ------------------------
|
| 59 |
+
# 辅助函数 (No changes needed here, includes retry and uses visible GPU)
|
| 60 |
+
# ------------------------
|
| 61 |
+
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
|
| 62 |
+
"""
|
| 63 |
+
从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。
|
| 64 |
+
"""
|
| 65 |
+
idx = random.randint(0, len(common_voice_dataset) - 1)
|
| 66 |
+
sample = common_voice_dataset.select([idx])[0]
|
| 67 |
+
audio = sample['audio']
|
| 68 |
+
waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU
|
| 69 |
+
sr = audio['sampling_rate']
|
| 70 |
+
if sr != sample_rate:
|
| 71 |
+
if waveform.dim() > 1:
|
| 72 |
+
waveform = waveform.mean(dim=0)
|
| 73 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
|
| 74 |
+
waveform = resampler(waveform)
|
| 75 |
+
if waveform.dim() == 1:
|
| 76 |
+
waveform = waveform.unsqueeze(0)
|
| 77 |
+
if waveform.numel() == 0 or not sample['raw_text']:
|
| 78 |
+
logging.warning("Got an empty prompt, trying again...")
|
| 79 |
+
return get_random_prompt(common_voice_dataset, sample_rate)
|
| 80 |
+
# Return CPU tensor, CosyVoice inference should handle moving it
|
| 81 |
+
return waveform, sample['raw_text']
|
| 82 |
+
|
| 83 |
+
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
|
| 84 |
+
"""
|
| 85 |
+
利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。
|
| 86 |
+
Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
|
| 87 |
+
"""
|
| 88 |
+
last_exception = None
|
| 89 |
+
for attempt in range(max_retries):
|
| 90 |
+
try:
|
| 91 |
+
prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
|
| 92 |
+
|
| 93 |
+
all_speech = []
|
| 94 |
+
inference_generator = cosyvoice.inference_zero_shot(
|
| 95 |
+
query_text,
|
| 96 |
+
prompt_text,
|
| 97 |
+
prompt_speech, # Pass CPU tensor
|
| 98 |
+
stream=stream,
|
| 99 |
+
text_frontend=False
|
| 100 |
+
)
|
| 101 |
+
for i, chunk in enumerate(inference_generator):
|
| 102 |
+
if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
|
| 103 |
+
all_speech.append(chunk['tts_speech'])
|
| 104 |
+
else:
|
| 105 |
+
logging.warning(f"TTS Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
|
| 106 |
+
|
| 107 |
+
if not all_speech:
|
| 108 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 109 |
+
raise ValueError("TTS inference finished but produced no audio chunks.")
|
| 110 |
+
|
| 111 |
+
combined_speech = torch.cat(all_speech, dim=-1) # On GPU
|
| 112 |
+
sample_rate_val = cosyvoice.sample_rate
|
| 113 |
+
|
| 114 |
+
return {
|
| 115 |
+
'audio_tensor': combined_speech, # Return GPU tensor
|
| 116 |
+
'sample_rate': sample_rate_val
|
| 117 |
+
}
|
| 118 |
+
except Exception as e:
|
| 119 |
+
last_exception = e
|
| 120 |
+
logging.error(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}", exc_info=True)
|
| 121 |
+
logging.error(f"Failed Text: '{query_text[:100]}...'")
|
| 122 |
+
logging.error(f"Prompt Text Used: '{prompt_text[:100]}...'")
|
| 123 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 124 |
+
if attempt < max_retries - 1:
|
| 125 |
+
wait_time = RETRY_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(0.5, 1.5)
|
| 126 |
+
logging.warning(f"Retrying TTS with a different prompt in {wait_time:.2f}s...")
|
| 127 |
+
time.sleep(wait_time)
|
| 128 |
+
else:
|
| 129 |
+
logging.error(f"All {max_retries} TTS attempts failed.")
|
| 130 |
+
|
| 131 |
+
logging.error(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
|
| 132 |
+
logging.error(f"Last TTS error: {last_exception}")
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
|
| 136 |
+
"""
|
| 137 |
+
针对从磁盘加载的 LLM rephrased 数据集中的单个样本进行 TTS 处理。
|
| 138 |
+
Processes example['query_rephrased']. <--- Target the rephrased query
|
| 139 |
+
"""
|
| 140 |
+
# --- Target the 'query_rephrased' field from the LLM output dataset ---
|
| 141 |
+
text_to_convert = example.get('query_rephrased') # <--- Use 'query_rephrased' field
|
| 142 |
+
if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "":
|
| 143 |
+
original_query = example.get('query', [{}])[0].get('content', 'Original Query Missing')[:50]
|
| 144 |
+
logging.warning(f"Skipping TTS for example due to missing or empty 'query_rephrased' field. Original query started with: '{original_query}...'. Status was: {example.get('query_rephrased_status','N/A')}")
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
# --- Use the text_to_audio function with retry logic ---
|
| 148 |
+
audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)
|
| 149 |
+
|
| 150 |
+
if audio_result is not None:
|
| 151 |
+
audio_tensor = audio_result['audio_tensor'] # Still on GPU here
|
| 152 |
+
if audio_tensor.dim() == 1:
|
| 153 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
| 154 |
+
elif audio_tensor.dim() > 2:
|
| 155 |
+
logging.warning(f"Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
|
| 156 |
+
audio_tensor = audio_tensor.view(1, -1)
|
| 157 |
+
|
| 158 |
+
if audio_tensor.numel() == 0:
|
| 159 |
+
logging.warning(f"Generated audio tensor is empty for rephrased query: '{text_to_convert[:60]}...'")
|
| 160 |
+
return None
|
| 161 |
+
|
| 162 |
+
return {
|
| 163 |
+
'audio_tensor': audio_tensor, # Return GPU tensor
|
| 164 |
+
'sample_rate': audio_result['sample_rate']
|
| 165 |
+
}
|
| 166 |
+
else:
|
| 167 |
+
# text_to_audio already logged the failure
|
| 168 |
+
return None
|
| 169 |
+
|
| 170 |
+
# ------------------------
|
| 171 |
+
# 数据加载与模型初始化
|
| 172 |
+
# ------------------------
|
| 173 |
+
logging.info("Loading VoxPopuli (as Common Voice) dataset for prompts...")
|
| 174 |
+
try:
|
| 175 |
+
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
|
| 176 |
+
logging.info(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
|
| 177 |
+
if len(common_voice) == 0:
|
| 178 |
+
raise ValueError("VoxPopuli dataset loaded but contains no samples.")
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logging.error(f"Error loading VoxPopuli dataset: {e}", exc_info=True)
|
| 181 |
+
sys.exit(1)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
logging.info("Initializing CosyVoice2 model...")
|
| 185 |
+
try:
|
| 186 |
+
cosyvoice = CosyVoice2(
|
| 187 |
+
'/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',
|
| 188 |
+
load_jit=True,
|
| 189 |
+
load_trt=False,
|
| 190 |
+
fp16=False # Consider setting to True if VRAM is an issue and you have FP16 support
|
| 191 |
+
)
|
| 192 |
+
logging.info(f"CosyVoice model initialized on effective device: {effective_device}")
|
| 193 |
+
except Exception as e:
|
| 194 |
+
logging.error(f"Error initializing CosyVoice2 model: {e}", exc_info=True)
|
| 195 |
+
sys.exit(1)
|
| 196 |
+
|
| 197 |
+
logging.info(f"Loading rephrased dataset from disk: {REPHRASED_DATASET_PATH}")
|
| 198 |
+
try:
|
| 199 |
+
# --- Load the single dataset saved by the LLM script ---
|
| 200 |
+
rephrased_dataset = load_from_disk(REPHRASED_DATASET_PATH)
|
| 201 |
+
if not rephrased_dataset:
|
| 202 |
+
raise ValueError(f"Dataset loaded from '{REPHRASED_DATASET_PATH}' is empty or invalid.")
|
| 203 |
+
# --- Wrap it in a dict to match the loop structure expecting splits ---
|
| 204 |
+
# Use the assumed split name as the key
|
| 205 |
+
dataset_dict = {ASSUMED_INPUT_SPLIT: rephrased_dataset}
|
| 206 |
+
logging.info(f"Successfully loaded dataset with {len(rephrased_dataset)} examples.")
|
| 207 |
+
except FileNotFoundError:
|
| 208 |
+
logging.error(f"Error: Rephrased dataset not found at '{REPHRASED_DATASET_PATH}'.")
|
| 209 |
+
logging.error("Please ensure the LLM rephrasing script ran successfully and saved data to the correct location.")
|
| 210 |
+
sys.exit(1)
|
| 211 |
+
except Exception as e:
|
| 212 |
+
logging.error(f"Error loading rephrased dataset from '{REPHRASED_DATASET_PATH}': {e}", exc_info=True)
|
| 213 |
+
sys.exit(1)
|
| 214 |
+
|
| 215 |
+
# 创建输出目录
|
| 216 |
+
os.makedirs(TTS_OUTPUT_PATH, exist_ok=True)
|
| 217 |
+
|
| 218 |
+
# ------------------------
|
| 219 |
+
# 主处理循环
|
| 220 |
+
# ------------------------
|
| 221 |
+
final_dataset_dict_for_tracking = {} # To track which final datasets were saved
|
| 222 |
+
|
| 223 |
+
# Iterate through the dictionary (will contain only one split, e.g., 'train')
|
| 224 |
+
for split_name, split_dataset in dataset_dict.items():
|
| 225 |
+
logging.info(f"\nProcessing split: {split_name} with {len(split_dataset)} examples for TTS")
|
| 226 |
+
# Output directory for *this* script's results (audio + final dataset)
|
| 227 |
+
split_output_dir = os.path.join(TTS_OUTPUT_PATH, split_name)
|
| 228 |
+
os.makedirs(split_output_dir, exist_ok=True)
|
| 229 |
+
logging.info(f"Audio files and final data for this split will be saved in: {split_output_dir}")
|
| 230 |
+
|
| 231 |
+
# 用于断点续跑的进度记录 (specific to this TTS process)
|
| 232 |
+
progress_file = os.path.join(split_output_dir, "tts_progress.txt") # Use specific name
|
| 233 |
+
start_index = 0
|
| 234 |
+
if os.path.exists(progress_file):
|
| 235 |
+
try:
|
| 236 |
+
with open(progress_file, "r") as f:
|
| 237 |
+
content = f.read().strip()
|
| 238 |
+
if content:
|
| 239 |
+
start_index = int(content)
|
| 240 |
+
logging.info(f"Resuming split '{split_name}' TTS from sample index {start_index}")
|
| 241 |
+
else:
|
| 242 |
+
logging.info(f"Progress file '{progress_file}' is empty, starting TTS from index 0.")
|
| 243 |
+
start_index = 0
|
| 244 |
+
except ValueError:
|
| 245 |
+
logging.warning(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.")
|
| 246 |
+
start_index = 0
|
| 247 |
+
except Exception as e:
|
| 248 |
+
logging.error(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
|
| 249 |
+
start_index = 0
|
| 250 |
+
else:
|
| 251 |
+
logging.info(f"No progress file found at '{progress_file}'. Starting TTS from index 0.")
|
| 252 |
+
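# A minimal sketch of the resume contract used above (drawn from the read/write
# code in this loop, not a separate API): the progress file holds a single
# integer, the index of the next sample to process, e.g. a run that finished
# samples 0-249 leaves a file containing "250". A hypothetical helper with the
# same parsing behaviour would look like:
#   def read_progress(path, default=0):
#       try:
#           with open(path) as f:
#               return int(f.read().strip() or default)
#       except (FileNotFoundError, ValueError):
#           return default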
|
| 253 |
+
|
| 254 |
+
# --- [NEW] Section: Check and Save Already Completed Samples ---
|
| 255 |
+
already_processed_samples = []
|
| 256 |
+
if start_index > 0:
|
| 257 |
+
logging.info(f"Checking for already processed samples (audio files) up to index {start_index - 1}...")
|
| 258 |
+
# Use tqdm here for visibility if start_index is large
|
| 259 |
+
for j in tqdm(range(start_index), desc="Checking existing audio"):
|
| 260 |
+
potential_output_wav_path = os.path.join(split_output_dir, f"{j}.wav")
|
| 261 |
+
if os.path.exists(potential_output_wav_path):
|
| 262 |
+
try:
|
| 263 |
+
# Ensure index is valid before accessing
|
| 264 |
+
if j < len(split_dataset):
|
| 265 |
+
original_sample = split_dataset[j]
|
| 266 |
+
# Create a dict with all original keys + the existing audio path
|
| 267 |
+
completed_sample_dict = {k: original_sample[k] for k in original_sample.keys()}
|
| 268 |
+
completed_sample_dict["audio_filepath"] = potential_output_wav_path # Point to the existing audio
|
| 269 |
+
already_processed_samples.append(completed_sample_dict)
|
| 270 |
+
else:
|
| 271 |
+
logging.warning(f"Index {j} is out of bounds for the loaded dataset (size {len(split_dataset)}) while checking existing files. Skipping this index.")
|
| 272 |
+
except Exception as e:
|
| 273 |
+
logging.error(f"Error processing data for existing sample index {j}: {e}")
|
| 274 |
+
|
| 275 |
+
if already_processed_samples:
|
| 276 |
+
logging.info(f"Found {len(already_processed_samples)} samples with existing audio files before the resume point.")
|
| 277 |
+
# Define path for the dataset of already processed samples
|
| 278 |
+
already_processed_dataset_path = os.path.join(split_output_dir, "already_processed_dataset")
|
| 279 |
+
try:
|
| 280 |
+
logging.info(f"Saving these {len(already_processed_samples)} already processed samples to: {already_processed_dataset_path}")
|
| 281 |
+
|
| 282 |
+
# Define features based on the original dataset + the new audio_filepath column
|
| 283 |
+
original_features = split_dataset.features
|
| 284 |
+
new_features_dict = original_features.copy()
|
| 285 |
+
if "audio_filepath" not in new_features_dict:
|
| 286 |
+
new_features_dict["audio_filepath"] = Value('string')
|
| 287 |
+
new_features = Features(new_features_dict)
|
| 288 |
+
|
| 289 |
+
already_processed_dataset = Dataset.from_list(already_processed_samples, features=new_features)
|
| 290 |
+
already_processed_dataset.save_to_disk(already_processed_dataset_path)
|
| 291 |
+
logging.info(f"Successfully saved dataset of {len(already_processed_samples)} already processed samples.")
|
| 292 |
+
# Clear the list to free memory
|
| 293 |
+
del already_processed_samples
|
| 294 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache after potential large list processing
|
| 295 |
+
except Exception as e:
|
| 296 |
+
logging.error(f"Failed to create or save the dataset of already processed samples: {e}", exc_info=True)
|
| 297 |
+
# Keep already_processed_samples list in memory in case of save failure? Maybe not needed.
|
| 298 |
+
else:
|
| 299 |
+
logging.info("No existing audio files found for samples before the resume point (index 0 to {}).".format(start_index - 1))
|
| 300 |
+
# --- [END NEW] Section ---
|
| 301 |
+
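# Illustrative note (an assumption about downstream use, not something this
# script does): the resumed samples are saved separately in
# 'already_processed_dataset' while this run later saves its own 'final_dataset',
# so a consumer would typically merge the two, e.g.:
#   from datasets import load_from_disk, concatenate_datasets
#   merged = concatenate_datasets([
#       load_from_disk(os.path.join(split_output_dir, "already_processed_dataset")),
#       load_from_disk(os.path.join(split_output_dir, "final_dataset")),
#   ])
# Both pieces are built with the same Features, so concatenation is safe.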
|
| 302 |
+
|
| 303 |
+
# --- Main processing loop ---
|
| 304 |
+
final_samples = [] # List to hold ALL processed sample dictionaries for the FINAL dataset of this run
|
| 305 |
+
logging.info(f"Starting/Resuming TTS processing from index {start_index}...")
|
| 306 |
+
pbar = tqdm(range(start_index, len(split_dataset)), desc=f"TTS on '{split_name}' query_rephrased", initial=start_index, total=len(split_dataset))
|
| 307 |
+
for i in pbar:
|
| 308 |
+
try:
|
| 309 |
+
sample = split_dataset[i] # Sample data is on CPU
|
| 310 |
+
except IndexError:
|
| 311 |
+
logging.error(f"Index {i} is out of bounds for split_dataset (size {len(split_dataset)}). Stopping processing.")
|
| 312 |
+
break # Stop if we somehow go out of bounds
|
| 313 |
+
|
| 314 |
+
# Define path for the *new* audio file (or potentially existing one)
|
| 315 |
+
output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
|
| 316 |
+
|
| 317 |
+
# Check if this specific TTS output file ALREADY exists
|
| 318 |
+
# This handles cases where the script stopped AFTER saving audio but BEFORE updating progress
|
| 319 |
+
# OR if files were somehow generated but progress file was lost/reset.
|
| 320 |
+
if os.path.exists(output_wav_path):
|
| 321 |
+
logging.debug(f"Audio file already exists for index {i} at {output_wav_path}. Skipping TTS, adding to final list.")
|
| 322 |
+
# Create the dict with all original keys + the existing audio path
|
| 323 |
+
sample_dict = {k: sample[k] for k in sample.keys()}
|
| 324 |
+
sample_dict["audio_filepath"] = output_wav_path # Point to the existing audio
|
| 325 |
+
final_samples.append(sample_dict) # Add to the list for the *final* dataset
|
| 326 |
+
# Update progress even if skipped due to existing file (important!)
|
| 327 |
+
with open(progress_file, "w") as f:
|
| 328 |
+
f.write(str(i + 1))
|
| 329 |
+
continue # Move to the next sample
|
| 330 |
+
|
| 331 |
+
# --- Perform TTS on the 'query_rephrased' field (UNCHANGED CORE LOGIC) ---
|
| 332 |
+
result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
|
| 333 |
+
|
| 334 |
+
if result is not None:
|
| 335 |
+
audio_tensor = result['audio_tensor'] # Received tensor is on GPU
|
| 336 |
+
sample_rate_val = result['sample_rate']
|
| 337 |
+
|
| 338 |
+
try:
|
| 339 |
+
# --- Move tensor to CPU before saving ---
|
| 340 |
+
audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
|
| 341 |
+
if audio_tensor_save.dim() == 1:
|
| 342 |
+
audio_tensor_save = audio_tensor_save.unsqueeze(0)
|
| 343 |
+
elif audio_tensor_save.dim() > 2:
|
| 344 |
+
audio_tensor_save = audio_tensor_save.view(1, -1)
|
| 345 |
+
|
| 346 |
+
torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
|
| 347 |
+
|
| 348 |
+
# --- Preserve all original fields + add the NEW audio path ---
|
| 349 |
+
sample_dict = {k: sample[k] for k in sample.keys()}
|
| 350 |
+
sample_dict["audio_filepath"] = output_wav_path # Add the path to the new audio
|
| 351 |
+
final_samples.append(sample_dict) # Add to the list for the *final* dataset
|
| 352 |
+
|
| 353 |
+
# Explicitly delete GPU tensor and clear cache
|
| 354 |
+
del audio_tensor
|
| 355 |
+
del audio_tensor_save
|
| 356 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
except Exception as e:
|
| 360 |
+
logging.error(f"Failed to save wav for sample {i} (TTS of query_rephrased) at {output_wav_path}: {e}", exc_info=True)
|
| 361 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache on save error too
|
| 362 |
+
else:
|
| 363 |
+
# process_example already logged the failure
|
| 364 |
+
logging.warning(f"Sample {i} TTS failed after retries (Rephrased Query: '{str(sample.get('query_rephrased', 'N/A'))[:60]}...'), no audio generated. This sample will NOT be included in the final dataset.")
|
| 365 |
+
# Decide whether to add sample without audio path or skip it
|
| 366 |
+
# Skipping for now, as audio is the goal. If you wanted to include failed ones:
|
| 367 |
+
# sample_dict = {k: sample[k] for k in sample.keys()}
|
| 368 |
+
# sample_dict["audio_filepath"] = None # Indicate missing audio
|
| 369 |
+
# final_samples.append(sample_dict)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# Update progress file after processing each sample (success or failure to ensure resume point advances)
|
| 373 |
+
# Make sure this is outside the 'if result is not None' block
|
| 374 |
+
with open(progress_file, "w") as f:
|
| 375 |
+
f.write(str(i + 1))
|
| 376 |
+
|
| 377 |
+
# Optional: More frequent cache clearing (Uncomment if needed)
|
| 378 |
+
# if i % 50 == 0 and torch.cuda.is_available():
|
| 379 |
+
# torch.cuda.empty_cache()
|
| 380 |
+
# logging.debug(f"Cleared CUDA cache at index {i}")
|
| 381 |
+
|
| 382 |
+
# --- Final cache clear after finishing the split ---
|
| 383 |
+
if torch.cuda.is_available():
|
| 384 |
+
logging.info("Clearing final CUDA cache for the split.")
|
| 385 |
+
torch.cuda.empty_cache()
|
| 386 |
+
|
| 387 |
+
# --- Save the final dataset object for this split (contains items found + newly generated) ---
|
| 388 |
+
if final_samples:
|
| 389 |
+
# Define features based on the original dataset + the new audio_filepath column for the final dataset
|
| 390 |
+
final_features_dict = split_dataset.features.copy()
|
| 391 |
+
if "audio_filepath" not in final_features_dict:
|
| 392 |
+
final_features_dict["audio_filepath"] = Value('string')
|
| 393 |
+
final_features = Features(final_features_dict)
|
| 394 |
+
|
| 395 |
+
try:
|
| 396 |
+
logging.info(f"Attempting to create final dataset object from {len(final_samples)} collected samples...")
|
| 397 |
+
final_dataset_obj = Dataset.from_list(final_samples, features=final_features)
|
| 398 |
+
|
| 399 |
+
# Save the final dataset object inside the split's output directory
|
| 400 |
+
final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") # Name for the complete dataset
|
| 401 |
+
|
| 402 |
+
logging.info(f"Saving final dataset object for split '{split_name}' (with audio paths) to {final_dataset_save_path}...")
|
| 403 |
+
os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True) # Should exist, but safety check
|
| 404 |
+
final_dataset_obj.save_to_disk(final_dataset_save_path)
|
| 405 |
+
logging.info(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples in the final dataset at '{final_dataset_save_path}'.")
|
| 406 |
+
final_dataset_dict_for_tracking[split_name] = final_dataset_obj # Keep track if needed
|
| 407 |
+
|
| 408 |
+
except Exception as e:
|
| 409 |
+
logging.error(f"Error creating or saving final dataset object for split '{split_name}': {e}", exc_info=True)
|
| 410 |
+
logging.error("Attempting to save final_samples list as JSON Lines as a fallback...")
|
| 411 |
+
fallback_path = os.path.join(split_output_dir, "final_samples_fallback.jsonl") # Use distinct fallback name
|
| 412 |
+
try:
|
| 413 |
+
with open(fallback_path, 'w', encoding='utf-8') as f:
|
| 414 |
+
for item in final_samples:
|
| 415 |
+
# Basic serialization attempt
|
| 416 |
+
serializable_item = {}
|
| 417 |
+
for k, v in item.items():
|
| 418 |
+
if isinstance(v, torch.Tensor):
|
| 419 |
+
serializable_item[k] = f"Tensor data (shape: {v.shape})" # Placeholder
|
| 420 |
+
elif isinstance(v, (dict, list, str, int, float, bool, type(None))):
|
| 421 |
+
serializable_item[k] = v
|
| 422 |
+
else:
|
| 423 |
+
serializable_item[k] = str(v) # Attempt string conversion for others
|
| 424 |
+
f.write(json.dumps(serializable_item) + '\n')
|
| 425 |
+
logging.info(f"Fallback JSON Lines saved to {fallback_path}")
|
| 426 |
+
except Exception as json_e:
|
| 427 |
+
logging.error(f"Fallback JSON save failed: {json_e}", exc_info=True)
|
| 428 |
+
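# Sketch of how the fallback file could be recovered later (an assumption, not
# part of this pipeline): the JSON Lines dump can be reloaded with
#   recovered = Dataset.from_json(fallback_path)
# using the datasets.Dataset import at the top of this file; note that any
# non-JSON-serializable fields were stringified above, so they come back as strings.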
|
| 429 |
+
else:
|
| 430 |
+
logging.warning(f"Finished processing split: {split_name}. No samples were successfully processed or found with existing audio during this run to add to the final dataset.")
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
print("="*30)
|
| 434 |
+
if final_dataset_dict_for_tracking:
|
| 435 |
+
logging.info(f"All specified splits processed for TTS. Final datasets saved in respective 'final_dataset' subdirectories within '{TTS_OUTPUT_PATH}'.")
|
| 436 |
+
logging.info(f"Processed splits where final datasets were generated: {list(final_dataset_dict_for_tracking.keys())}")
|
| 437 |
+
logging.info("Additionally, if resuming, datasets containing only the samples processed *before* this run may have been saved in 'already_processed_dataset' subdirectories.")
|
| 438 |
+
else:
|
| 439 |
+
logging.warning(f"TTS processing finished, but no final datasets were generated or saved in the 'final_dataset' subdirectories within '{TTS_OUTPUT_PATH}'. Check logs for errors. Pre-existing data might be in 'already_processed_dataset' if resuming.")
|
| 440 |
+
print("="*30)
|
r1-a/dataset/examqa_rewrite.py
ADDED
|
@@ -0,0 +1,487 @@
| 1 |
+
import os
|
| 2 |
+
import http.client
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import random
|
| 6 |
+
from datasets import load_dataset, Dataset, DatasetDict, Features, Value
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
import sys
|
| 9 |
+
import logging
|
| 10 |
+
import getpass
|
| 11 |
+
import signal
|
| 12 |
+
import socket
|
| 13 |
+
import concurrent.futures
|
| 14 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
+
import argparse # For command-line arguments
|
| 16 |
+
|
| 17 |
+
# --- Configuration --- (Mostly Same)
|
| 18 |
+
DATASET_NAME = "virtuoussy/Multi-subject-RLVR"
|
| 19 |
+
DATASET_SPLIT = "train"
|
| 20 |
+
API_HOST = "api2.aigcbest.top"
|
| 21 |
+
API_PATH = "/v1/chat/completions"
|
| 22 |
+
LLM_MODEL = "gpt-4.1-mini"
|
| 23 |
+
API_KEY = os.environ.get('AIGCBEST_API_KEY', "YOUR_API_KEY_HERE") # Read the key from the environment; the placeholder default keeps the check below meaningful and avoids hard-coding a real key
|
| 24 |
+
if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
|
| 25 |
+
print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.")
|
| 26 |
+
sys.exit(1)
|
| 27 |
+
|
| 28 |
+
OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
|
| 29 |
+
# Define the path where the *potentially incomplete* processed dataset exists
|
| 30 |
+
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
|
| 31 |
+
# Define where the *final, retried* dataset will be saved
|
| 32 |
+
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final") # Save to new location initially
|
| 33 |
+
|
| 34 |
+
BATCH_SAVE_SIZE = 500 # How often to save intermediate progress *during retry*
|
| 35 |
+
MAX_WORKERS = 20
|
| 36 |
+
REQUEST_DELAY_SECONDS = 0.15
|
| 37 |
+
MAX_RETRIES = 3
|
| 38 |
+
|
| 39 |
+
# Setup logging
|
| 40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 41 |
+
logging.getLogger("datasets").setLevel(logging.WARNING)
|
| 42 |
+
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
| 43 |
+
|
| 44 |
+
# --- LLM API Function (call_llm_api) ---
|
| 45 |
+
# Use the robust version from the previous answer
|
| 46 |
+
def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
|
| 47 |
+
system_prompt = (
|
| 48 |
+
"You are an expert linguist specializing in converting structured prompts or "
|
| 49 |
+
"fill-in-the-blank problems into natural, spoken-language questions suitable for "
|
| 50 |
+
"text-to-speech (TTS). Your goal is to make the question sound like how a person "
|
| 51 |
+
"would naturally ask it. "
|
| 52 |
+
"If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
|
| 53 |
+
"rephrase it as a direct question asking for the missing information. "
|
| 54 |
+
"Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
|
| 55 |
+
"Focus only on rephrasing the *user's question* part provided. "
|
| 56 |
+
"Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
|
| 57 |
+
)
|
| 58 |
+
payload = json.dumps({
|
| 59 |
+
"model": model,
|
| 60 |
+
"messages": [
|
| 61 |
+
{"role": "system", "content": system_prompt},
|
| 62 |
+
{"role": "user", "content": original_question}
|
| 63 |
+
],
|
| 64 |
+
})
|
| 65 |
+
headers = {
|
| 66 |
+
'Accept': 'application/json',
|
| 67 |
+
'Authorization': f'Bearer {api_key}',
|
| 68 |
+
'User-Agent': 'HuggingFace Dataset Processing Script (Retry Mode)',
|
| 69 |
+
'Content-Type': 'application/json'
|
| 70 |
+
}
|
| 71 |
+
time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))
|
| 72 |
+
|
| 73 |
+
for attempt in range(retries):
|
| 74 |
+
logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...")
|
| 75 |
+
try:
|
| 76 |
+
conn = http.client.HTTPSConnection(host, timeout=60)
|
| 77 |
+
conn.request("POST", path, payload, headers)
|
| 78 |
+
res = conn.getresponse()
|
| 79 |
+
status = res.status
|
| 80 |
+
data = res.read()
|
| 81 |
+
conn.close()
|
| 82 |
+
|
| 83 |
+
if status == 200:
|
| 84 |
+
response_json = json.loads(data.decode("utf-8"))
|
| 85 |
+
if response_json.get("choices") and len(response_json["choices"]) > 0:
|
| 86 |
+
message = response_json["choices"][0].get("message")
|
| 87 |
+
if message and message.get("content"):
|
| 88 |
+
rephrased = message["content"].strip()
|
| 89 |
+
if len(rephrased) > 1 and ((rephrased.startswith('"') and rephrased.endswith('"')) or \
|
| 90 |
+
(rephrased.startswith("'") and rephrased.endswith("'"))):
|
| 91 |
+
rephrased = rephrased[1:-1]
|
| 92 |
+
if rephrased and rephrased.strip().lower() != original_question.strip().lower():
|
| 93 |
+
logging.debug(f"Successfully rephrased: {rephrased[:80]}...")
|
| 94 |
+
return rephrased
|
| 95 |
+
elif not rephrased:
|
| 96 |
+
logging.warning(f"LLM returned empty response for: {original_question[:50]}...")
|
| 97 |
+
return None
|
| 98 |
+
else:
|
| 99 |
+
logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
|
| 100 |
+
return None # Treat identical as failure for rephrasing
|
| 101 |
+
logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
|
| 102 |
+
return None
|
| 103 |
+
elif status == 429:
|
| 104 |
+
retry_after_header = res.getheader('Retry-After', '5')
|
| 105 |
+
try: wait_time = int(retry_after_header)
|
| 106 |
+
except ValueError: wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
|
| 107 |
+
logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
|
| 108 |
+
time.sleep(wait_time)
|
| 109 |
+
elif status >= 500:
|
| 110 |
+
wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
|
| 111 |
+
logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
|
| 112 |
+
time.sleep(wait_time)
|
| 113 |
+
else:
|
| 114 |
+
logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
|
| 115 |
+
return None
|
| 116 |
+
except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
|
| 117 |
+
logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
|
| 118 |
+
if attempt + 1 == retries: return None
|
| 119 |
+
wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
|
| 120 |
+
logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
|
| 121 |
+
time.sleep(wait_time)
|
| 122 |
+
except json.JSONDecodeError as e:
|
| 123 |
+
logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200]}")
|
| 124 |
+
if attempt + 1 == retries: return None
|
| 125 |
+
wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
|
| 126 |
+
time.sleep(wait_time) # Wait before next attempt
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
|
| 129 |
+
if attempt + 1 == retries: return None
|
| 130 |
+
wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
|
| 131 |
+
logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
|
| 132 |
+
time.sleep(wait_time)
|
| 133 |
+
|
| 134 |
+
logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
|
| 135 |
+
return None
|
| 136 |
+
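# Illustrative input/output for the rephrasing call above (a made-up example of
# the behaviour the system prompt asks for, not taken from the dataset):
#   original:  "The derivative of x^2 with respect to x is -----."
#   rephrased: "What is the derivative of x squared with respect to x?"
# A return value of None means the call failed, returned an empty string, or
# echoed the input unchanged, and the caller records a failure status instead.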
|
| 137 |
+
|
| 138 |
+
# --- Dataset Processing Function (rephrase_query_entry) ---
|
| 139 |
+
# Same as before, returns the full dictionary with status
|
| 140 |
+
def rephrase_query_entry(example):
|
| 141 |
+
processed_example = example.copy()
|
| 142 |
+
# Ensure status field exists, default to unprocessed if missing
|
| 143 |
+
if 'query_rephrased_status' not in processed_example:
|
| 144 |
+
processed_example['query_rephrased_status'] = 'unprocessed'
|
| 145 |
+
|
| 146 |
+
original_query_list = example.get("query")
|
| 147 |
+
|
| 148 |
+
# --- Input Validation ---
|
| 149 |
+
if original_query_list is None:
|
| 150 |
+
processed_example['query_rephrased_status'] = 'skipped_missing_query_column'
|
| 151 |
+
processed_example['query_rephrased'] = None
|
| 152 |
+
return processed_example
|
| 153 |
+
if not isinstance(original_query_list, list):
|
| 154 |
+
processed_example['query_rephrased_status'] = 'skipped_query_not_list'
|
| 155 |
+
processed_example['query_rephrased'] = None
|
| 156 |
+
return processed_example
|
| 157 |
+
if not original_query_list:
|
| 158 |
+
processed_example['query_rephrased_status'] = 'skipped_query_list_empty'
|
| 159 |
+
processed_example['query_rephrased'] = None
|
| 160 |
+
return processed_example
|
| 161 |
+
|
| 162 |
+
# --- Find User Question ---
|
| 163 |
+
user_question = None
|
| 164 |
+
for i, message in enumerate(original_query_list):
|
| 165 |
+
if isinstance(message, dict) and message.get("role") == "user":
|
| 166 |
+
content = message.get("content")
|
| 167 |
+
if isinstance(content, str) and content.strip():
|
| 168 |
+
user_question = content
|
| 169 |
+
break
|
| 170 |
+
else:
|
| 171 |
+
processed_example['query_rephrased_status'] = 'skipped_invalid_user_content'
|
| 172 |
+
processed_example['query_rephrased'] = None
|
| 173 |
+
return processed_example
|
| 174 |
+
|
| 175 |
+
if not user_question:
|
| 176 |
+
processed_example['query_rephrased_status'] = 'skipped_no_user_content_found'
|
| 177 |
+
processed_example['query_rephrased'] = None
|
| 178 |
+
return processed_example
|
| 179 |
+
|
| 180 |
+
# --- Call LLM API ---
|
| 181 |
+
logging.info(f"Attempting to rephrase: {user_question[:60]}...") # Log retry attempt
|
| 182 |
+
rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)
|
| 183 |
+
|
| 184 |
+
# --- Update Example Based on API Result ---
|
| 185 |
+
if rephrased_query_content:
|
| 186 |
+
logging.debug(f"Rephrased '{user_question[:30]}...' to '{rephrased_query_content[:30]}...'")
|
| 187 |
+
processed_example["query_rephrased"] = rephrased_query_content
|
| 188 |
+
processed_example['query_rephrased_status'] = 'success_retried' # New status for successful retry
|
| 189 |
+
else:
|
| 190 |
+
logging.warning(f"Retry failed for user question: {user_question[:50]}...")
|
| 191 |
+
# Keep existing rephrased content (likely None) but update status
|
| 192 |
+
processed_example['query_rephrased_status'] = 'failed_llm_retry' # New status for failed retry
|
| 193 |
+
|
| 194 |
+
return processed_example
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --- Function to Save Progress ---
|
| 198 |
+
# Saves the *entire list* of dictionaries
|
| 199 |
+
def save_final_dataset(data_list, output_path):
|
| 200 |
+
"""Saves the final list of processed data dictionaries."""
|
| 201 |
+
if not data_list:
|
| 202 |
+
logging.info("No data provided for saving.")
|
| 203 |
+
return False
|
| 204 |
+
logging.info(f"Attempting to save {len(data_list)} final examples to {output_path}...")
|
| 205 |
+
try:
|
| 206 |
+
# Define features explicitly to handle potential Nones and ensure consistency
|
| 207 |
+
# Adjust types based on your actual dataset structure
|
| 208 |
+
features = Features({
|
| 209 |
+
'query': [{'role': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None)}],
|
| 210 |
+
'query_rephrased': Value(dtype='string', id=None), # Allow nulls
|
| 211 |
+
'query_rephrased_status': Value(dtype='string', id=None), # Allow nulls
|
| 212 |
+
# Add other columns from your original dataset here...
|
| 213 |
+
# Example: 'answer': Value(dtype='string', id=None),
|
| 214 |
+
# Example: 'subject': Value(dtype='string', id=None),
|
| 215 |
+
# IMPORTANT: List all columns present in your loaded dataset
|
| 216 |
+
'query_code': Value(dtype='string', id=None),
|
| 217 |
+
'answer': Value(dtype='string', id=None),
|
| 218 |
+
'answer_code': Value(dtype='string', id=None),
|
| 219 |
+
'subject': Value(dtype='string', id=None),
|
| 220 |
+
'grade': Value(dtype='string', id=None),
|
| 221 |
+
'source': Value(dtype='string', id=None),
|
| 222 |
+
'split': Value(dtype='string', id=None),
|
| 223 |
+
'__index_level_0__': Value(dtype='int64', id=None) # Check if this column exists
|
| 224 |
+
})
|
| 225 |
+
|
| 226 |
+
# Clean data slightly - replace python None with "" for string fields if needed by Arrow
|
| 227 |
+
# or ensure feature definition handles nulls correctly (Value(dtype='string', id=None) should)
|
| 228 |
+
# cleaned_data_list = []
|
| 229 |
+
# for item in data_list:
|
| 230 |
+
# cleaned_item = item.copy()
|
| 231 |
+
# for key, feature_type in features.items():
|
| 232 |
+
# if isinstance(feature_type, Value) and feature_type.dtype == 'string':
|
| 233 |
+
# if cleaned_item.get(key) is None:
|
| 234 |
+
# cleaned_item[key] = "" # Or keep None if schema allows
|
| 235 |
+
# cleaned_data_list.append(cleaned_item)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Use the original list directly if schema handles None
|
| 239 |
+
processed_dataset = Dataset.from_list(list(data_list), features=features)
|
| 240 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 241 |
+
processed_dataset.save_to_disk(output_path)
|
| 242 |
+
logging.info(f"Successfully saved final dataset ({len(data_list)} items) to {output_path}")
|
| 243 |
+
return True
|
| 244 |
+
except Exception as e:
|
| 245 |
+
logging.error(f"Failed to save final dataset to {output_path}: {e}", exc_info=True)
|
| 246 |
+
# Try saving as JSON as a fallback
|
| 247 |
+
fallback_json_path = output_path + ".jsonl"
|
| 248 |
+
logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
|
| 249 |
+
try:
|
| 250 |
+
with open(fallback_json_path, 'w', encoding='utf-8') as f:
|
| 251 |
+
for item in data_list:
|
| 252 |
+
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
| 253 |
+
logging.info(f"Successfully saved fallback JSON Lines file to {fallback_json_path}")
|
| 254 |
+
except Exception as json_e:
|
| 255 |
+
logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
|
| 256 |
+
return False
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# --- Helper to get original user query ---
|
| 260 |
+
def get_user_query(example):
|
| 261 |
+
"""Extracts the user query content from the 'query' list."""
|
| 262 |
+
query_list = example.get("query")
|
| 263 |
+
if isinstance(query_list, list):
|
| 264 |
+
for message in query_list:
|
| 265 |
+
if isinstance(message, dict) and message.get("role") == "user":
|
| 266 |
+
content = message.get("content")
|
| 267 |
+
if isinstance(content, str) and content.strip():
|
| 268 |
+
return content
|
| 269 |
+
return None
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# --- Function to Check if Retry is Needed ---
|
| 273 |
+
def needs_retry(example):
|
| 274 |
+
"""Determines if an example needs reprocessing based on its current state."""
|
| 275 |
+
status = example.get('query_rephrased_status')
|
| 276 |
+
rephrased_text = example.get('query_rephrased')
|
| 277 |
+
|
| 278 |
+
# Condition 1: Explicit failure status from previous (new script) run
|
| 279 |
+
if status in ['failed_llm_call', 'failed_llm_retry', 'failed_processing_exception']:
|
| 280 |
+
return True
|
| 281 |
+
|
| 282 |
+
# Condition 2: Certain 'skipped' statuses might warrant a retry (optional, adjust as needed)
|
| 283 |
+
# For example, if the user content was invalid originally, retrying won't help.
|
| 284 |
+
# if status in ['skipped_no_user_content_found']: # Decide if these should be retried
|
| 285 |
+
# return True
|
| 286 |
+
|
| 287 |
+
# Condition 3: Status indicates success OR status is missing/old,
|
| 288 |
+
# BUT the rephrased text is missing or empty. This catches failures
|
| 289 |
+
# from the *old* script or inconsistent states.
|
| 290 |
+
if rephrased_text is None or not str(rephrased_text).strip():
|
| 291 |
+
# Don't retry if it was intentionally skipped due to bad input
|
| 292 |
+
if status not in ['skipped_missing_query_column', 'skipped_query_not_list',
|
| 293 |
+
'skipped_query_list_empty', 'skipped_invalid_user_content',
|
| 294 |
+
'skipped_no_user_content_found']:
|
| 295 |
+
return True
|
| 296 |
+
|
| 297 |
+
# Condition 4 (Optional but recommended): Check if rephrased text is identical to original user query
|
| 298 |
+
# This requires extracting the original query here.
|
| 299 |
+
# original_user_query = get_user_query(example)
|
| 300 |
+
# if original_user_query and isinstance(rephrased_text, str) and \
|
| 301 |
+
# rephrased_text.strip().lower() == original_user_query.strip().lower():
|
| 302 |
+
# # Check status first - if it was intentionally skipped, don't retry
|
| 303 |
+
# if status not in ['skipped_missing_query_column', 'skipped_query_not_list',
|
| 304 |
+
# 'skipped_query_list_empty', 'skipped_invalid_user_content',
|
| 305 |
+
# 'skipped_no_user_content_found']:
|
| 306 |
+
# logging.debug(f"Identified identical query/rephrased text for retry: {original_user_query[:50]}...")
|
| 307 |
+
# return True
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Default: No retry needed
|
| 311 |
+
return False
|
| 312 |
+
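# A few concrete cases for the rules above (read directly off the conditions,
# listed only as a reading aid):
#   status 'failed_llm_call', 'failed_llm_retry', 'failed_processing_exception'  -> retry
#   empty/None query_rephrased with any status outside the skip list             -> retry
#   'skipped_query_not_list' (or another skip status) with no rephrase           -> no retry
#   non-empty query_rephrased and no failure status                              -> no retry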
|
| 313 |
+
|
| 314 |
+
# --- Main Execution ---
|
| 315 |
+
if __name__ == "__main__":
|
| 316 |
+
start_time = time.time()
|
| 317 |
+
logging.info("======================================================")
|
| 318 |
+
logging.info(f" Starting Dataset Processing Script in RETRY MODE")
|
| 319 |
+
logging.info("======================================================")
|
| 320 |
+
logging.info(f"Dataset: {DATASET_NAME}, Split: {DATASET_SPLIT}")
|
| 321 |
+
logging.info(f"Loading existing processed data from: {PROCESSED_DATA_PATH}")
|
| 322 |
+
logging.info(f"Final output will be saved to: {FINAL_OUTPUT_PATH}")
|
| 323 |
+
logging.info(f"Max concurrent workers: {MAX_WORKERS}")
|
| 324 |
+
|
| 325 |
+
# --- Load Existing Processed Dataset ---
|
| 326 |
+
if not os.path.exists(PROCESSED_DATA_PATH):
|
| 327 |
+
logging.error(f"Existing processed data not found at '{PROCESSED_DATA_PATH}'. Cannot run in retry mode.")
|
| 328 |
+
sys.exit(1)
|
| 329 |
+
|
| 330 |
+
logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
|
| 331 |
+
try:
|
| 332 |
+
# Load the dataset saved by the previous script run
|
| 333 |
+
existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
|
| 334 |
+
# Convert to list of dictionaries for easier modification access by index
|
| 335 |
+
# Be mindful of memory usage for very large datasets
|
| 336 |
+
results_list = existing_dataset.to_list()
|
| 337 |
+
total_examples = len(results_list)
|
| 338 |
+
logging.info(f"Loaded {total_examples} examples.")
|
| 339 |
+
# Ensure essential columns exist, add them if missing from old format
|
| 340 |
+
for i in range(total_examples):
|
| 341 |
+
if 'query_rephrased' not in results_list[i]:
|
| 342 |
+
results_list[i]['query_rephrased'] = None
|
| 343 |
+
if 'query_rephrased_status' not in results_list[i]:
|
| 344 |
+
results_list[i]['query_rephrased_status'] = 'unknown_original_status'
|
| 345 |
+
|
| 346 |
+
except Exception as e:
|
| 347 |
+
logging.error(f"Failed to load existing dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
|
| 348 |
+
sys.exit(1)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# --- Identify Indices to Retry ---
|
| 352 |
+
indices_to_retry = [
|
| 353 |
+
i for i, example in enumerate(results_list) if needs_retry(example)
|
| 354 |
+
]
|
| 355 |
+
num_to_retry = len(indices_to_retry)
|
| 356 |
+
|
| 357 |
+
if num_to_retry == 0:
|
| 358 |
+
logging.info("No examples found needing retry based on the criteria.")
|
| 359 |
+
logging.info(f"The dataset at {PROCESSED_DATA_PATH} is considered final.")
|
| 360 |
+
# Optional: You might still want to save it to FINAL_OUTPUT_PATH for consistency
|
| 361 |
+
# if not os.path.exists(FINAL_OUTPUT_PATH):
|
| 362 |
+
# save_final_dataset(results_list, FINAL_OUTPUT_PATH)
|
| 363 |
+
sys.exit(0)
|
| 364 |
+
|
| 365 |
+
logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")
|
| 366 |
+
|
| 367 |
+
# --- Prepare for Concurrent Retries ---
|
| 368 |
+
processed_count_in_retry = 0
|
| 369 |
+
# We don't need batch saving in the same way, but can update the list in memory
|
| 370 |
+
# A temporary dictionary to store results from futures before updating the main list
|
| 371 |
+
retry_results_dict = {}
|
| 372 |
+
|
| 373 |
+
logging.info("Starting concurrent processing for examples needing retry...")
|
| 374 |
+
|
| 375 |
+
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 376 |
+
# Submit jobs only for the indices needing retry
|
| 377 |
+
# Pass the *specific example dictionary* to the function
|
| 378 |
+
futures = {
|
| 379 |
+
executor.submit(rephrase_query_entry, results_list[i]): i
|
| 380 |
+
for i in indices_to_retry
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
try:
|
| 384 |
+
pbar = tqdm(total=num_to_retry, desc="Retrying failed examples", unit="example")
|
| 385 |
+
for future in concurrent.futures.as_completed(futures):
|
| 386 |
+
original_index = futures[future] # Get the index in the full results_list
|
| 387 |
+
try:
|
| 388 |
+
# Get the updated dictionary result from the retry attempt
|
| 389 |
+
updated_example_dict = future.result()
|
| 390 |
+
# Store the result temporarily, keyed by original index
|
| 391 |
+
retry_results_dict[original_index] = updated_example_dict
|
| 392 |
+
pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
|
| 393 |
+
|
| 394 |
+
except Exception as exc:
|
| 395 |
+
# Catch errors *during* the retry processing itself
|
| 396 |
+
logging.error(f'Retry for example index {original_index} generated an exception: {exc}', exc_info=True)
|
| 397 |
+
# Create a placeholder indicating the retry attempt failed due to an exception
|
| 398 |
+
error_placeholder = results_list[original_index].copy() # Get original data again
|
| 399 |
+
error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
|
| 400 |
+
# Store this error placeholder
|
| 401 |
+
retry_results_dict[original_index] = error_placeholder
|
| 402 |
+
pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
|
| 403 |
+
|
| 404 |
+
finally:
|
| 405 |
+
processed_count_in_retry += 1
|
| 406 |
+
pbar.update(1)
|
| 407 |
+
# Optional intermediate save logic (maybe save every N retries)
|
| 408 |
+
# Could save the *entire* potentially partially updated list, but might be slow.
|
| 409 |
+
# if processed_count_in_retry % BATCH_SAVE_SIZE == 0:
|
| 410 |
+
# logging.info(f"Processed {processed_count_in_retry} retries, updating intermediate state...")
|
| 411 |
+
# # Update the main list with results gathered so far
|
| 412 |
+
# for idx, updated_item in retry_results_dict.items():
|
| 413 |
+
# results_list[idx] = updated_item
|
| 414 |
+
# # Clear the temporary dict after updating
|
| 415 |
+
# retry_results_dict.clear()
|
| 416 |
+
# # Save the whole list (potentially slow)
|
| 417 |
+
# save_final_dataset(results_list, FINAL_OUTPUT_PATH + "_interim")
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
except KeyboardInterrupt:
|
| 421 |
+
logging.warning("\nCtrl+C detected during retry! Attempting to save progress...")
|
| 422 |
+
|
| 423 |
+
except Exception as e:
|
| 424 |
+
logging.error(f"An unexpected error occurred during the retry loop: {e}", exc_info=True)
|
| 425 |
+
|
| 426 |
+
finally:
|
| 427 |
+
if 'pbar' in locals():
|
| 428 |
+
pbar.close()
|
| 429 |
+
|
| 430 |
+
# --- Update the main results list with all completed retries ---
|
| 431 |
+
logging.info("Updating main results list with completed retry attempts...")
|
| 432 |
+
update_count = 0
|
| 433 |
+
for idx, updated_item in retry_results_dict.items():
|
| 434 |
+
if idx < len(results_list):
|
| 435 |
+
results_list[idx] = updated_item
|
| 436 |
+
update_count += 1
|
| 437 |
+
else:
|
| 438 |
+
logging.error(f"Index {idx} from retry results is out of bounds for results_list (size {len(results_list)}). Skipping update.")
|
| 439 |
+
|
| 440 |
+
logging.info(f"Applied updates for {update_count} retried items.")
|
| 441 |
+
|
| 442 |
+
# --- Final Save ---
|
| 443 |
+
logging.info(f"Attempting to save the final updated dataset to: {FINAL_OUTPUT_PATH}")
|
| 444 |
+
if save_final_dataset(results_list, FINAL_OUTPUT_PATH):
|
| 445 |
+
logging.info("Final dataset saved successfully.")
|
| 446 |
+
# Optional: Suggest deleting the old intermediate path if successful
|
| 447 |
+
# logging.info(f"You may now safely remove the intermediate directory: {PROCESSED_DATA_PATH}")
|
| 448 |
+
else:
|
| 449 |
+
logging.error(">>> FINAL SAVE FAILED! <<<")
|
| 450 |
+
logging.error(f"Check the logs. The latest state might be in memory or a fallback JSON file if created.")
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
# --- Final Verification (Optional) ---
|
| 454 |
+
logging.info("------------------------------------------------------")
|
| 455 |
+
logging.info("Verification: Loading final saved dataset for status check...")
|
| 456 |
+
try:
|
| 457 |
+
final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
|
| 458 |
+
logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
|
| 459 |
+
status_counts = {}
|
| 460 |
+
for ex in final_reloaded_dataset:
|
| 461 |
+
status = ex.get('query_rephrased_status', 'unknown_status_field')
|
| 462 |
+
status_counts[status] = status_counts.get(status, 0) + 1
|
| 463 |
+
|
| 464 |
+
logging.info("Status counts in the final saved file:")
|
| 465 |
+
for status, count in sorted(status_counts.items()):
|
| 466 |
+
logging.info(f" - {status}: {count}")
|
| 467 |
+
|
| 468 |
+
# Highlight remaining failures
|
| 469 |
+
# Count every status that marks a failure, including the original 'failed_llm_call'
# and the retry-exception statuses that carry an exception-type suffix.
remaining_failures = sum(count for status, count in status_counts.items() if status.startswith('failed'))
|
| 472 |
+
|
| 473 |
+
if remaining_failures > 0:
|
| 474 |
+
logging.warning(f"Found {remaining_failures} examples still marked as failed after retry attempts.")
|
| 475 |
+
else:
|
| 476 |
+
logging.info("All identified failures appear to have been successfully retried or were not retried.")
|
| 477 |
+
|
| 478 |
+
except FileNotFoundError:
|
| 479 |
+
logging.error(f"Verification failed: Final saved dataset not found at {FINAL_OUTPUT_PATH}.")
|
| 480 |
+
except Exception as e:
|
| 481 |
+
logging.error(f"Failed to reload or verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
end_time = time.time()
|
| 485 |
+
logging.info("------------------------------------------------------")
|
| 486 |
+
logging.info(f"Retry script finished in {end_time - start_time:.2f} seconds.")
|
| 487 |
+
logging.info("======================================================")
|
r1-a/dataset/final_tts.py
ADDED
|
@@ -0,0 +1,316 @@
| 1 |
+
# --- ENVIRONMENT VARIABLE CONTROL ---
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
from datasets import load_dataset, Dataset, load_from_disk, Features, Value, Audio
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import time
|
| 10 |
+
import logging
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
import pathlib
|
| 14 |
+
import re
|
| 15 |
+
import unicodedata
|
| 16 |
+
|
| 17 |
+
# --- Read Environment Variables ---
|
| 18 |
+
try:
|
| 19 |
+
# GPU ID for this specific run (0, 1, 2, or 3)
|
| 20 |
+
PROCESSING_GPU_ID = int(os.environ.get("PROCESSING_GPU_ID", 3))
|
| 21 |
+
# Total number of parallel runs (should be 4)
|
| 22 |
+
TOTAL_PROCESSING_NODES = int(os.environ.get("TOTAL_PROCESSING_NODES", 4))
|
| 23 |
+
except ValueError:
|
| 24 |
+
print("Error: PROCESSING_GPU_ID and TOTAL_PROCESSING_NODES env vars must be integers.")
|
| 25 |
+
sys.exit(1)
|
| 26 |
+
|
| 27 |
+
if not 0 <= PROCESSING_GPU_ID < TOTAL_PROCESSING_NODES:
|
| 28 |
+
print(f"Error: PROCESSING_GPU_ID ({PROCESSING_GPU_ID}) must be between 0 and {TOTAL_PROCESSING_NODES - 1}.")
|
| 29 |
+
sys.exit(1)
|
| 30 |
+
|
| 31 |
+
print(f"--- Starting Run for Shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} ---")
|
| 32 |
+
print(f"--- Targetting GPU Index (physical): {PROCESSING_GPU_ID} ---")
|
| 33 |
+
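# A minimal launch sketch, assuming one process per GPU and the two environment
# variables described above (hypothetical shell commands, shown only as a comment):
#   PROCESSING_GPU_ID=0 TOTAL_PROCESSING_NODES=4 python final_tts.py &
#   PROCESSING_GPU_ID=1 TOTAL_PROCESSING_NODES=4 python final_tts.py &
#   PROCESSING_GPU_ID=2 TOTAL_PROCESSING_NODES=4 python final_tts.py &
#   PROCESSING_GPU_ID=3 TOTAL_PROCESSING_NODES=4 python final_tts.py &
# Each instance then picks its own shard and writes into the shared output path.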
|
| 34 |
+
# --- SET VISIBLE CUDA DEVICE *BEFORE* TORCH IMPORT THAT USES CUDA ---
|
| 35 |
+
# This makes the chosen GPU appear as 'cuda:0' to this script instance
|
| 36 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(PROCESSING_GPU_ID)
|
| 37 |
+
|
| 38 |
+
# Check CUDA availability *after* setting visibility
|
| 39 |
+
if not torch.cuda.is_available():
|
| 40 |
+
print(f"ERROR: CUDA device {PROCESSING_GPU_ID} is not available after setting CUDA_VISIBLE_DEVICES.")
|
| 41 |
+
sys.exit(1)
|
| 42 |
+
else:
|
| 43 |
+
# PyTorch now sees the selected GPU as cuda:0
|
| 44 |
+
effective_device = torch.device("cuda:0")
|
| 45 |
+
try:
|
| 46 |
+
print(f"Script process {os.getpid()} successfully assigned to specific GPU: {torch.cuda.get_device_name(0)} (Original Index {PROCESSING_GPU_ID})")
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f"Warning: Could not get device name, but CUDA is available. Error: {e}")
|
| 49 |
+
print(f"Script process {os.getpid()} assigned to specific GPU index {PROCESSING_GPU_ID}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# --- Add CosyVoice Path ---
|
| 53 |
+
COSYVOICE_PATH = '/home/chenyifu/CosyVoice' # <-- Your path
|
| 54 |
+
if COSYVOICE_PATH not in sys.path:
|
| 55 |
+
sys.path.append(COSYVOICE_PATH)
|
| 56 |
+
|
| 57 |
+
# Import CosyVoice
|
| 58 |
+
try:
|
| 59 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 60 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 61 |
+
except ImportError as e:
|
| 62 |
+
print(f"Error importing CosyVoice: {e}")
|
| 63 |
+
sys.exit(1)
|
| 64 |
+
|
| 65 |
+
# Setup basic logging for this instance
|
| 66 |
+
logging.basicConfig(level=logging.INFO, format=f'%(asctime)s - %(levelname)s - [GPU-{PROCESSING_GPU_ID}] %(message)s')
|
| 67 |
+
|
| 68 |
+
# ------------------------
|
| 69 |
+
# Configuration Parameters - Mostly unchanged
|
| 70 |
+
# ------------------------
|
| 71 |
+
# --- Input Dataset ---
|
| 72 |
+
INPUT_DATASET_PATH = '/home/chenyifu/audio-r1/r1-a/dataset/prompt_only'
|
| 73 |
+
TEXT_FIELD_FOR_TTS = "question_text"
|
| 74 |
+
AUDIO_PATH_FIELD = "question_audio" # Field name for the *final* aggregated dataset
|
| 75 |
+
ASSUMED_INPUT_SPLIT_NAME = "train"
|
| 76 |
+
|
| 77 |
+
# --- Output ---
|
| 78 |
+
TTS_OUTPUT_BASE_PATH = '/home/chenyifu/audio-r1/r1-a/dataset/prompt_only_fully_merged_with_audio' # <<-- SHARED output path for all runs
|
| 79 |
+
AUDIO_SUBFOLDER_NAME = 'audio_files' # <<-- SHARED audio subfolder
|
| 80 |
+
|
| 81 |
+
# --- Prompt Audio Settings ---
|
| 82 |
+
RAW_PROMPT_DATASET_PATH = "/home/chenyifu/audio-r1/r1-a/dataset/mls_eng10k"
|
| 83 |
+
PROMPT_DATASET_SPLIT = "train"
|
| 84 |
+
PROMPT_MIN_DURATION_S = 10
|
| 85 |
+
PROMPT_MAX_DURATION_S = 13
|
| 86 |
+
PROMPT_TEXT_FIELD = "transcript"
|
| 87 |
+
PROMPT_AUDIO_DURATION_FIELD = "audio_duration"
|
| 88 |
+
FILTERED_PROMPT_DATASET_PATH = f"{RAW_PROMPT_DATASET_PATH}_filtered_{PROMPT_MIN_DURATION_S}_{PROMPT_MAX_DURATION_S}s" # SHARED path
|
| 89 |
+
|
| 90 |
+
# --- TTS Settings ---
|
| 91 |
+
TARGET_SAMPLE_RATE = 16000 # Desired *prompt* sample rate
|
| 92 |
+
MAX_TTS_RETRIES = 3
|
| 93 |
+
RETRY_DELAY_SECONDS = 2
|
| 94 |
+
|
| 95 |
+
# --- Processing Settings ---
|
| 96 |
+
# TEST_SINGLE_SAMPLE = False # Not needed, shard logic handles subset
|
| 97 |
+
# MULTI_GPU_PROCESSING = False # Not using mp.spawn
|
| 98 |
+
|
| 99 |
+
# ------------------------
|
| 100 |
+
# Helper Functions - Unchanged from the previous multi-GPU-capable script
|
| 101 |
+
# ------------------------
|
| 102 |
+
def preprocess_text(text):
|
| 103 |
+
if not isinstance(text, str): return ""
|
| 104 |
+
text = unicodedata.normalize('NFKC', text)
|
| 105 |
+
text = re.sub(r'[—–―‐‑⁃﹣-]', ' ', text)
|
| 106 |
+
text = text.replace('\u00AD', '').replace('\u200B', '')
|
| 107 |
+
text = re.sub(r'\s+', ' ', text).strip()
|
| 108 |
+
if text and text[-1] not in ['.', '?', '!']: text += '.'
|
| 109 |
+
return text
|
| 110 |
+
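# Example of what the cleaning above does (worked through the rules by hand,
# shown only as a comment):
#   preprocess_text("Solve  2–3\u00ad plus one")
#     -> "Solve 2 3 plus one."
#   (dash replaced by a space, soft hyphen dropped, whitespace collapsed,
#    terminal period appended)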
|
| 111 |
+
def filter_prompt_logic(example):
|
| 112 |
+
if PROMPT_AUDIO_DURATION_FIELD in example and isinstance(example[PROMPT_AUDIO_DURATION_FIELD], (int, float)):
|
| 113 |
+
duration = example[PROMPT_AUDIO_DURATION_FIELD]; return PROMPT_MIN_DURATION_S <= duration <= PROMPT_MAX_DURATION_S
|
| 114 |
+
else:
|
| 115 |
+
try:
|
| 116 |
+
audio_info = example['audio']; samplerate = audio_info['sampling_rate']; duration = len(audio_info['array']) / samplerate
|
| 117 |
+
return PROMPT_MIN_DURATION_S <= duration <= PROMPT_MAX_DURATION_S
|
| 118 |
+
except Exception: return False  # any decode error disqualifies this prompt
|
| 119 |
+
|
| 120 |
+
def get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate=TARGET_SAMPLE_RATE):
|
| 121 |
+
if not filtered_prompt_dataset or len(filtered_prompt_dataset) == 0: raise ValueError("Filtered prompt dataset empty!")
|
| 122 |
+
idx = random.randint(0, len(filtered_prompt_dataset) - 1)
|
| 123 |
+
try: sample = filtered_prompt_dataset[idx]
|
| 124 |
+
except IndexError: sample = filtered_prompt_dataset[0]
|
| 125 |
+
audio_info = sample['audio']; prompt_text = sample[PROMPT_TEXT_FIELD]
|
| 126 |
+
if isinstance(audio_info, dict) and 'array' in audio_info: waveform = torch.tensor(audio_info['array'], dtype=torch.float32); sr = audio_info['sampling_rate']
|
| 127 |
+
elif isinstance(audio_info, str) or (isinstance(audio_info, dict) and 'path' in audio_info):
|
| 128 |
+
path = audio_info if isinstance(audio_info, str) else audio_info['path']; waveform, sr = torchaudio.load(path)
|
| 129 |
+
else: raise TypeError("Unknown prompt audio format")
|
| 130 |
+
if not prompt_text or waveform.numel() == 0: return get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate)
|
| 131 |
+
if sr != target_sample_rate:
|
| 132 |
+
if waveform.dim() > 1 and waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True)
|
| 133 |
+
elif waveform.dim() == 1: waveform = waveform.unsqueeze(0)
|
| 134 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate); waveform = resampler(waveform)
|
| 135 |
+
if waveform.dim()==1: waveform = waveform.unsqueeze(0)
|
| 136 |
+
elif waveform.shape[0] > 1 : waveform = waveform.mean(dim=0, keepdim=True)
|
| 137 |
+
return waveform.cpu(), prompt_text
|
| 138 |
+
|
| 139 |
+
def text_to_audio(text_to_convert, cosyvoice, filtered_prompt_dataset, target_sample_rate, stream=False, max_retries=MAX_TTS_RETRIES):
|
| 140 |
+
cleaned_text = preprocess_text(text_to_convert)
|
| 141 |
+
if not cleaned_text: logging.warning(f"Empty text after cleaning: '{text_to_convert[:60]}...'"); return None
|
| 142 |
+
last_exception = None
|
| 143 |
+
for attempt in range(max_retries):
|
| 144 |
+
try:
|
| 145 |
+
prompt_speech, prompt_text = get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate)
|
| 146 |
+
all_speech = []
|
| 147 |
+
inference_generator = cosyvoice.inference_zero_shot(cleaned_text, prompt_text, prompt_speech, stream=stream)
|
| 148 |
+
for i, chunk in enumerate(inference_generator):
|
| 149 |
+
if 'tts_speech' in chunk and chunk['tts_speech'] is not None: all_speech.append(chunk['tts_speech'])
|
| 150 |
+
if not all_speech: raise ValueError(f"TTS produced no audio chunks. Cleaned: '{cleaned_text[:60]}...'")
|
| 151 |
+
combined_speech = torch.cat(all_speech, dim=-1); actual_sample_rate = cosyvoice.sample_rate
|
| 152 |
+
return {'audio_tensor': combined_speech, 'sample_rate': actual_sample_rate}
|
| 153 |
+
except Exception as e:
|
| 154 |
+
last_exception = e; logging.error(f"TTS Error Attempt {attempt + 1}: {e}", exc_info=False)
|
| 155 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 156 |
+
if attempt < max_retries - 1: time.sleep(RETRY_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(0.5, 1.5))
|
| 157 |
+
else: logging.error(f"All TTS retries failed for: '{cleaned_text[:60]}...'")
|
| 158 |
+
return None
|
| 159 |
+
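# Note on the contract of text_to_audio (restating what the code above returns,
# for the caller below): on success it returns
#   {'audio_tensor': a float tensor (typically shaped [1, T]) still on the model's device,
#    'sample_rate': cosyvoice.sample_rate}
# and on failure (empty text after cleaning, or all retries exhausted) it returns
# None, so the caller must check for None and move the tensor to CPU before saving.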
|
| 160 |
+
# -----------------------------
|
| 161 |
+
# --- Main Execution Logic ----
# -----------------------------
if __name__ == "__main__":
    # --- Load or Create Filtered Prompt Dataset (Safe for concurrent runs, only first one creates) ---
    # Add a small delay + check to mitigate potential race condition on creation
    if not os.path.exists(FILTERED_PROMPT_DATASET_PATH):
        time.sleep(random.uniform(0, 2))  # Small random delay
        if not os.path.exists(FILTERED_PROMPT_DATASET_PATH):  # Double check
            logging.info(f"Filtered prompt dataset not found. Attempting creation...")
            try:
                prompt_dataset_raw = load_dataset(RAW_PROMPT_DATASET_PATH, split=PROMPT_DATASET_SPLIT)
                if 'audio' in prompt_dataset_raw.features and not isinstance(prompt_dataset_raw.features['audio'], Audio):
                    prompt_dataset_raw = prompt_dataset_raw.cast_column("audio", Audio(decode=True))
                filtered_ds = prompt_dataset_raw.filter(filter_prompt_logic, num_proc=max(1, os.cpu_count() // 2))
                if len(filtered_ds) == 0: raise ValueError("No prompts left after filtering.")
                cols = ['audio', PROMPT_TEXT_FIELD]
                if PROMPT_AUDIO_DURATION_FIELD in prompt_dataset_raw.column_names: cols.append(PROMPT_AUDIO_DURATION_FIELD)
                filtered_ds = filtered_ds.select_columns(cols)
                filtered_ds.save_to_disk(FILTERED_PROMPT_DATASET_PATH)
                logging.info(f"Filtered prompt dataset CREATED and saved to: {FILTERED_PROMPT_DATASET_PATH}")
            except Exception as e:
                logging.error(f"FATAL: Failed to create filtered prompt dataset: {e}", exc_info=True)
                # If creation fails, other processes might also fail loading. Check logs.
                sys.exit(1)
        else:
            logging.info(f"Filtered prompt dataset appeared while waiting. Proceeding.")

    try:
        logging.info(f"Loading filtered prompt dataset from: {FILTERED_PROMPT_DATASET_PATH}")
        filtered_prompt_dataset = load_from_disk(FILTERED_PROMPT_DATASET_PATH)
        logging.info(f"Loaded {len(filtered_prompt_dataset)} filtered prompts.")
    except Exception as e:
        logging.error(f"FATAL: Failed to load filtered prompt dataset: {e}", exc_info=True)
        sys.exit(1)

    # --- Initialize TTS Model (on the assigned GPU 'cuda:0') ---
    logging.info("Initializing CosyVoice model for this process...")
    try:
        # Model will initialize on cuda:0, which corresponds to the selected physical GPU
        cosyvoice = CosyVoice2(
            f'{COSYVOICE_PATH}/pretrained_models/CosyVoice2-0.5B',
            load_jit=True, load_trt=False, fp16=False
        )
        model_output_sr = cosyvoice.sample_rate
        logging.info(f"CosyVoice initialized. Model output SR: {model_output_sr}")
    except Exception as e:
        logging.error(f"Error initializing CosyVoice2 model: {e}", exc_info=True)
        sys.exit(1)

    # --- Load Main Input Dataset ---
    logging.info(f"Loading main input dataset from: {INPUT_DATASET_PATH}")
    try:
        input_dataset_full = load_from_disk(INPUT_DATASET_PATH)
        dataset_size = len(input_dataset_full)
        logging.info(f"Loaded main dataset with {dataset_size} examples.")

        # --- Add original index column ---
        logging.info("Adding original index...")
        def add_index(example, idx): example['original_index'] = idx; return example
        input_dataset_with_indices = input_dataset_full.map(add_index, with_indices=True, num_proc=max(1, os.cpu_count() // 2))
        logging.info("Original index added.")

        # --- Shard the dataset for this specific run ---
        logging.info(f"Selecting shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} for processing...")
        dataset_shard = input_dataset_with_indices.shard(
            num_shards=TOTAL_PROCESSING_NODES,
            index=PROCESSING_GPU_ID,
            contiguous=True  # Potentially faster access
        )
        shard_size = len(dataset_shard)
        logging.info(f"This instance will process {shard_size} samples.")

    except Exception as e:
        logging.error(f"Error loading or sharding main input dataset: {e}", exc_info=True)
        sys.exit(1)

    # --- Define Shared Output Audio Directory ---
    # All instances write to the same place
    split_output_dir = os.path.join(TTS_OUTPUT_BASE_PATH, ASSUMED_INPUT_SPLIT_NAME)
    split_audio_dir = os.path.join(split_output_dir, AUDIO_SUBFOLDER_NAME)
    os.makedirs(split_audio_dir, exist_ok=True)  # Ensure it exists

    # --- Process the Assigned Shard ---
    logging.info(f"Starting TTS processing for shard {PROCESSING_GPU_ID + 1}...")
    pbar = tqdm(total=shard_size, desc=f"GPU-{PROCESSING_GPU_ID} TTS", ncols=100)
    for sample in dataset_shard:
        try:
            original_idx = sample['original_index']
            text_to_convert = sample.get(TEXT_FIELD_FOR_TTS)

            if not text_to_convert or not isinstance(text_to_convert, str) or not text_to_convert.strip():
                logging.warning(f"Skipping original index {original_idx}: missing/invalid text.")
                continue  # pbar is updated once in the finally block below

            # Define audio path using original index -> SHARED audio dir
            audio_filename = f"query_{original_idx}.wav"
            absolute_audio_path = os.path.join(split_audio_dir, audio_filename)

            # Check if audio already exists (supports resuming any run)
            if os.path.exists(absolute_audio_path):
                logging.debug(f"Audio exists for original index {original_idx}, skipping.")
                continue  # pbar is updated once in the finally block below

            # Perform TTS
            tts_result = text_to_audio(
                text_to_convert,
                cosyvoice,
                filtered_prompt_dataset,
                TARGET_SAMPLE_RATE,  # Prompt SR
                stream=False
            )

            if tts_result is not None:
                audio_tensor = tts_result['audio_tensor']  # GPU tensor
                output_sample_rate = tts_result['sample_rate']  # Use model's actual SR

                try:
                    audio_tensor_cpu = audio_tensor.detach().cpu().to(torch.float32)
                    if audio_tensor_cpu.dim() == 1: audio_tensor_cpu = audio_tensor_cpu.unsqueeze(0)
                    elif audio_tensor_cpu.dim() > 2: audio_tensor_cpu = audio_tensor_cpu.view(1, -1)

                    # Save to the SHARED audio directory
                    torchaudio.save(absolute_audio_path, audio_tensor_cpu, output_sample_rate)
                    logging.debug(f"Saved audio for original index {original_idx}")

                    del audio_tensor, audio_tensor_cpu
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

                except Exception as e:
                    logging.error(f"Failed to save audio for original index {original_idx}: {e}")
                    if 'audio_tensor' in locals(): del audio_tensor
                    if 'audio_tensor_cpu' in locals(): del audio_tensor_cpu
                    if torch.cuda.is_available(): torch.cuda.empty_cache()
            else:
                logging.warning(f"TTS failed for original index {original_idx} after retries.")

        except Exception as e:
            logging.error(f"Unexpected error processing sample for original index {sample.get('original_index', 'UNKNOWN')}: {e}", exc_info=True)
            if torch.cuda.is_available(): torch.cuda.empty_cache()

        finally:
            pbar.update(1)

    pbar.close()
    logging.info(f"--- Finished processing shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} ---")

    # --- End of Script ---
    print("="*30)
    print(f"Run for GPU {PROCESSING_GPU_ID} completed.")
    print(f"Audio files (if successful) saved in: {split_audio_dir}")
    print("IMPORTANT: Run the aggregation script AFTER all 4 runs are finished to create the final dataset object.")
    print("="*30)
|
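The script above deliberately stops after writing per-index wav files and defers building the final dataset object to a separate aggregation step, which is not part of this commit. The sketch below is an illustrative assumption of how that step could look once all shard runs finish; the file name aggregate_tts_outputs.py and the two path constants are hypothetical stand-ins for the INPUT_DATASET_PATH / TTS_OUTPUT_BASE_PATH configuration used above.

# aggregate_tts_outputs.py -- illustrative sketch only; not the project's actual aggregation script.
import os
from datasets import load_from_disk, Dataset

INPUT_DATASET_PATH = "./input_dataset"      # assumption: same source dataset the shard runs used
AUDIO_DIR = "./tts_output/train/audio"      # assumption: <TTS_OUTPUT_BASE_PATH>/<split>/<audio subfolder>

def aggregate():
    # Re-read the original dataset so row order matches the original_index used for file names.
    source = load_from_disk(INPUT_DATASET_PATH)
    records = []
    for idx, example in enumerate(source):
        wav_path = os.path.join(AUDIO_DIR, f"query_{idx}.wav")
        if not os.path.exists(wav_path):
            continue  # the sample failed or was skipped in every shard run
        records.append({**example, "original_index": idx, "audio_filepath": wav_path})
    merged = Dataset.from_list(records)
    merged.save_to_disk(os.path.join(os.path.dirname(AUDIO_DIR), "final_dataset"))
    print(f"Aggregated {len(merged)} of {len(source)} samples.")

if __name__ == "__main__":
    aggregate()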
r1-a/dataset/gsm8k.py
ADDED
|
@@ -0,0 +1,169 @@
|
import os
import random
import torch
import torchaudio
from datasets import load_dataset, Dataset
import sys
from tqdm import tqdm

sys.path.append('/root/autodl-tmp/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# Configuration
COMMON_VOICE_LANGUAGE = "en"
DATASET_NAME = "gsm8k"  # Use the gsm8k dataset
OUTPUT_DATASET_PATH = './gsm8k_with_audio'
SAMPLE_RATE = 16000

# --- Helper functions ---

def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from the Common Voice dataset to use as a prompt.
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    sample = common_voice_dataset.select([idx])[0]
    audio = sample['audio']
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    sr = audio['sampling_rate']
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    return waveform.unsqueeze(0), sample['raw_text']

def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
    """
    Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
    """
    try:
        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
        # Optional: save prompt.wav for debugging
        # torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE)
        all_speech = []
        for i, j in enumerate(cosyvoice.inference_zero_shot(
                query_text,
                prompt_text,
                prompt_speech,
                stream=stream,
                text_frontend=False
        )):
            all_speech.append(j['tts_speech'])
        # Concatenate all generated speech chunks into one long tensor
        combined_speech = torch.cat(all_speech, dim=-1)
        sample_rate_val = cosyvoice.sample_rate
        return {'audio_tensor': combined_speech, 'sample_rate': sample_rate_val}
    except Exception as e:
        print(f"Error converting text to audio: {e}")
        return None

def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS for a single sample from the gsm8k dataset.
    Assumes the question text lives in the 'question' field
    and the answer in the 'answer' field.
    """
    query = example['question']
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is not None:
        # Return the generated audio tensor and its sample rate
        return {
            'audio_tensor': audio_result['audio_tensor'],
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None

# --- Data loading and model initialization ---

print("Loading Common Voice dataset...")
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
print(f"Total Common Voice {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")

print("Initializing CosyVoice2 model...")
cosyvoice = CosyVoice2(
    '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # Replace with the actual model path
    load_jit=True,
    load_trt=False,
    fp16=False
)

print("Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", 'main')

# Make sure the top-level output directory exists
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)

# --- Main processing loop ---
# Each split is processed separately; every sample's .wav file is saved and recorded in the final dataset
final_dataset_dict = {}  # Holds the final dataset object for each split

for split_name, split_dataset in dataset.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Progress record used for resuming interrupted runs
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                start_index = int(f.read().strip())
            print(f"Resuming split '{split_name}' from sample index {start_index}")
        except Exception as e:
            print(f"Failed to read progress file: {e}")

    final_samples = []  # Stores the records that make up the final dataset
    for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
        if i < start_index:
            # If the sample was already processed, reuse the existing wav file path (assumed to have been generated earlier)
            sample = split_dataset[i]
            wav_path = os.path.join(split_output_dir, f"{i}.wav")
            # Only add it to the final dataset if the file actually exists
            if os.path.exists(wav_path):
                final_samples.append({
                    "question_text": sample["question"],
                    "answer": sample["answer"],
                    "audio_filepath": wav_path
                })
            continue

        sample = split_dataset[i]
        # Run the TTS conversion
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            # Make sure the audio tensor has shape (channels, samples)
            audio_tensor = result['audio_tensor']
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            sample_rate_val = result['sample_rate']
            output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
            try:
                torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
            except Exception as e:
                print(f"Failed to save wav for sample {i}: {e}")
                continue

            # Record the converted sample in the final dataset
            final_samples.append({
                "question_text": sample["question"],
                "answer": sample["answer"],
                "audio_filepath": output_wav_path
            })
        else:
            print(f"Sample {i} processing failed, no audio generated.")

        # Update the progress record
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

    # Save the current split's final dataset to disk as a Hugging Face Dataset
    final_dataset_obj = Dataset.from_list(final_samples)
    final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
    final_dataset_obj.save_to_disk(final_dataset_save_path)
    print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
    final_dataset_dict[split_name] = final_dataset_obj

print("All splits processed; final datasets have been saved.")
|
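The split datasets written by gsm8k.py store audio_filepath as a plain string column (see the dataset_info.json further down in this commit), so downstream consumers have to decode the wav files themselves. A minimal sketch, assuming the default OUTPUT_DATASET_PATH above, of reloading a saved split and casting the path column to a decoded datasets.Audio feature:

# Sketch: load a split produced by gsm8k.py and attach decoded audio (paths below assume the defaults above).
from datasets import load_from_disk, Audio

ds = load_from_disk("./gsm8k_with_audio/test/final_dataset")
# Cast the string file path into an Audio feature so each sample decodes to an array on access.
ds = ds.cast_column("audio_filepath", Audio())
example = ds[0]
print(example["question_text"][:80])
print(example["audio_filepath"]["array"].shape, example["audio_filepath"]["sampling_rate"])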
r1-a/dataset/gsm8k_with_audio/test/299.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:faf7e5efc632e11631a878b477523f9a010ad5e3bba9db05f3bc20cafadab95e
|
| 3 |
+
size 2277200
|
r1-a/dataset/gsm8k_with_audio/test/301.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5010c552eb3c63b659ef8d71af813db7706b55c302f9c9e1c17a831c8bc896a4
|
| 3 |
+
size 2346320
|
r1-a/dataset/gsm8k_with_audio/test/302.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:06e41bddcccc46feb3fd116d32c245ef65cf0434acda1313d020d286a947ddcd
|
| 3 |
+
size 917840
|
r1-a/dataset/gsm8k_with_audio/test/314.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:026282fcebd0c6ba5bef04291128bbc5dd82eadeb2c020faa0a53135477a7e19
|
| 3 |
+
size 1332560
|
r1-a/dataset/gsm8k_with_audio/test/316.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:83a29e59fc73ac1eb3b49033da596e47141d5911ceae61effc736b4b307a17de
|
| 3 |
+
size 1505360
|
r1-a/dataset/gsm8k_with_audio/test/350.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5fd7f8dbaf220cd8fe66a151c056e384e07b13cf76778983bbc605043e43d4d
|
| 3 |
+
size 1946960
|
r1-a/dataset/gsm8k_with_audio/test/358.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a9416095b31c0d238bb1321e6108733973fa651cfcee1fca8f898d18470a89b
|
| 3 |
+
size 1363280
|
r1-a/dataset/gsm8k_with_audio/test/359.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1ba8571451ad8a8bce5fdf235517e0aae650b7a42d9cc763e92adc7e9c949de
|
| 3 |
+
size 1532240
|
r1-a/dataset/gsm8k_with_audio/test/369.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4365c6f3f261264849987b3a13a1d552c1004977c0afea259b9e29dad8fee08c
|
| 3 |
+
size 1866320
|
r1-a/dataset/gsm8k_with_audio/test/372.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a22d041239c1485f42ac79c20fd1ae2a9174fef598e36030c516c882663eae3c
|
| 3 |
+
size 1251920
|
r1-a/dataset/gsm8k_with_audio/test/376.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65160b2855b4f25f29a7e0f4b5da970cba1e60ae1f1e637e5adfd7495c896f79
|
| 3 |
+
size 1198160
|
r1-a/dataset/gsm8k_with_audio/test/385.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3bbcea803a885d3d5cc8b17f23837a1820f119f9564006023f8781b96a2114f4
|
| 3 |
+
size 2165840
|
r1-a/dataset/gsm8k_with_audio/test/394.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec4937195a9d64a026cd352a09943a56ee64344fedda73cb792d4d12b466fe3f
|
| 3 |
+
size 1344080
|
r1-a/dataset/gsm8k_with_audio/test/395.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7f84925ad72d43ff5b7bf24aa71472f08bc0ff5b43a09ed92929848451401ff0
|
| 3 |
+
size 1804880
|
r1-a/dataset/gsm8k_with_audio/test/397.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc7c4027e38fd2090b026f8112151fd564d5a80df9ff230dbac69cb082f60c59
|
| 3 |
+
size 1056080
|
r1-a/dataset/gsm8k_with_audio/test/400.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf0dace278147001a9b14f5701f5d206039b26438e50f19bcaf6015b46dab78e
|
| 3 |
+
size 867920
|
r1-a/dataset/gsm8k_with_audio/test/401.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a494472207f4c7f2d067dfa1ce3c19d18702b8fca5385aa2f23b7b7a9a869767
|
| 3 |
+
size 940880
|
r1-a/dataset/gsm8k_with_audio/test/447.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f08dd8d815e539e2d6b2ea2c06320784e3f6b3f862e3d50e20a2c11ac38785b
|
| 3 |
+
size 1586000
|
r1-a/dataset/gsm8k_with_audio/test/45.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a65adddc9e7cb0a684486badfd68a41610e2acd64e30b8e8a7da76607586092
|
| 3 |
+
size 2112080
|
r1-a/dataset/gsm8k_with_audio/test/450.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e42276c81767cb468b27f095f71e5a2765272a753ee9d69cb399998561ee3348
|
| 3 |
+
size 2104400
|
r1-a/dataset/gsm8k_with_audio/test/454.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c94820c5d8bdd41bdaf461ca950650950c0e44fa2427cb86c9ae59bfc8f781fc
|
| 3 |
+
size 599120
|
r1-a/dataset/gsm8k_with_audio/test/457.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3e3aaab17db00c2275f1ac1ebb8d2c710f4e7393488ca0a4fcdb45d57b6da2d9
|
| 3 |
+
size 1774160
|
r1-a/dataset/gsm8k_with_audio/test/458.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:813d8661f38ffba540960d4c50dd877714a196acd0d9fa1d2f9a9285ab35b40e
|
| 3 |
+
size 1082960
|
r1-a/dataset/gsm8k_with_audio/test/459.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48f55b9dbcda81e8e98ac58c7e428de7ed472f79d4dbbf763d8f6232d53b3e00
|
| 3 |
+
size 2945360
|
r1-a/dataset/gsm8k_with_audio/test/463.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c0710e41b29894e6e2d4519a60c00ccdba41584f0a8fa1e9c4800bf69d5b63f
|
| 3 |
+
size 1298000
|
r1-a/dataset/gsm8k_with_audio/test/465.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccdf5bacc8e5c5c24f086471ad15811f56e355f2814963aaa2f2c94ce916da05
|
| 3 |
+
size 1747280
|
r1-a/dataset/gsm8k_with_audio/test/515.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48ea6967d2fa10e62b5e74b695650c9cf4f5a06d4a232937627e75f35b93e526
|
| 3 |
+
size 1278800
|
r1-a/dataset/gsm8k_with_audio/test/877.wav
ADDED
|
Binary file (38.5 kB). View file
|
|
|
r1-a/dataset/gsm8k_with_audio/test/964.wav
ADDED
|
Binary file (88.4 kB). View file
|
|
|
r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json
ADDED
|
@@ -0,0 +1,20 @@
|
| 1 |
+
{
|
| 2 |
+
"citation": "",
|
| 3 |
+
"description": "",
|
| 4 |
+
"features": {
|
| 5 |
+
"question_text": {
|
| 6 |
+
"dtype": "string",
|
| 7 |
+
"_type": "Value"
|
| 8 |
+
},
|
| 9 |
+
"answer": {
|
| 10 |
+
"dtype": "string",
|
| 11 |
+
"_type": "Value"
|
| 12 |
+
},
|
| 13 |
+
"audio_filepath": {
|
| 14 |
+
"dtype": "string",
|
| 15 |
+
"_type": "Value"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"homepage": "",
|
| 19 |
+
"license": ""
|
| 20 |
+
}
|
r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json
ADDED
|
@@ -0,0 +1,13 @@
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "a037f9a8bcdfc025",
|
| 8 |
+
"_format_columns": null,
|
| 9 |
+
"_format_kwargs": {},
|
| 10 |
+
"_format_type": null,
|
| 11 |
+
"_output_all_columns": false,
|
| 12 |
+
"_split": null
|
| 13 |
+
}
|
r1-a/dataset/gsm8k_with_audio/test/progress.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
1319
|
r1-a/dataset/pkusafe.py
ADDED
|
@@ -0,0 +1,171 @@
|
import json
from collections import defaultdict
from datasets import load_dataset, Dataset  # make sure Dataset is imported
from tqdm.auto import tqdm
import traceback
import os  # needed to create the output directory

# --- Configuration ---
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF"
SPLIT_TO_PROCESS = "train"  # or 'test', etc.
OUTPUT_DATASET_DIR = "pku_saferlhf_filtered_unsafe_diverse_hf"  # output directory name

# --- Main script logic (same as the previous version) ---

def get_true_harm_categories(harm_category_dict):
    """Extract the keys (category names) whose value is True from the harm_category dict."""
    if not isinstance(harm_category_dict, dict):
        return []
    return [category for category, is_present in harm_category_dict.items() if is_present]

def filter_pku_saferlhf_detailed(dataset_name: str, split: str):
    """
    Load and filter the PKU-SafeRLHF dataset, keeping samples that contain an unsafe response,
    extracting the harm category names involved, and covering as many of the originally
    present harm category names as possible.
    """
    print(f"Loading dataset: {dataset_name}, split: {split}...")
    try:
        dataset = load_dataset(dataset_name, split=split)
        print("Dataset loaded.")
    except Exception as e:
        print(f"Error: could not load dataset {dataset_name} (split: {split}). Check the name and available splits.")
        print(f"Error details: {e}")
        return None, set()

    # --- Step 1: collect every harm_category name that appears in the original dataset ---
    print("Collecting all original harm_category names...")
    all_harm_category_names_in_split = set()
    scan_errors = 0
    for example in tqdm(dataset, desc="Scanning harm_category"):
        try:
            for harm_dict_key in ["response_0_harm_category", "response_1_harm_category"]:
                harm_dict = example.get(harm_dict_key)
                if isinstance(harm_dict, dict):
                    true_categories = get_true_harm_categories(harm_dict)
                    all_harm_category_names_in_split.update(true_categories)
        except Exception as e:
            scan_errors += 1
            continue

    if scan_errors > 0:
        print(f"Warning: {scan_errors} samples raised errors while scanning harm_category.")
    if not all_harm_category_names_in_split:
        print("Error: no harm_category names could be extracted from the dataset. Check the dataset structure or contents.")
        return None, set()
    print(f"Found {len(all_harm_category_names_in_split)} unique harm_category names in the original dataset.")

    # --- Step 2: keep samples with at least one unsafe response and extract the categories involved ---
    print("Filtering samples that contain an unsafe response and extracting the categories involved...")
    filtered_data_list = []  # a plain list, to distinguish it from the final Dataset object
    covered_categories_in_filtered = set()
    processed_count = 0
    filtered_count = 0
    filter_errors = 0

    for example in tqdm(dataset, desc="Filtering unsafe samples"):
        processed_count += 1
        try:
            prompt = example.get("prompt")
            response_0 = example.get("response_0")
            response_1 = example.get("response_1")
            is_safe_0 = example.get("is_response_0_safe")
            is_safe_1 = example.get("is_response_1_safe")
            response_0_harm_dict = example.get("response_0_harm_category")
            response_1_harm_dict = example.get("response_1_harm_category")

            if not all([prompt, response_0, response_1,
                        is_safe_0 is not None, is_safe_1 is not None]):
                continue

            if not is_safe_0 or not is_safe_1:
                involved_harm_categories = set()

                if not is_safe_0:
                    categories_0 = get_true_harm_categories(response_0_harm_dict)
                    involved_harm_categories.update(categories_0)

                if not is_safe_1:
                    categories_1 = get_true_harm_categories(response_1_harm_dict)
                    involved_harm_categories.update(categories_1)

                # (optional filtering logic)
                # if not involved_harm_categories:
                #     continue

                filtered_sample = {
                    "prompt": prompt,
                    "response_0": response_0,
                    "response_1": response_1,
                    "is_safe_0": is_safe_0,
                    "is_safe_1": is_safe_1,
                    # **** Note: Dataset objects handle lists of strings well ****
                    "involved_harm_categories": sorted(list(involved_harm_categories)),
                    "better_response_id": example.get("better_response_id"),
                    "safer_response_id": example.get("safer_response_id"),
                    # add other fields here as needed
                }
                filtered_data_list.append(filtered_sample)
                covered_categories_in_filtered.update(involved_harm_categories)
                filtered_count += 1

        except Exception as e:
            filter_errors += 1
            continue

    if filter_errors > 0:
        print(f"Warning: {filter_errors} samples raised errors during filtering.")
    print(f"Filtering complete. Processed {processed_count} samples; {filtered_count} matched the criteria.")

    # --- Step 3: check harm_category coverage ---
    missing_categories = all_harm_category_names_in_split - covered_categories_in_filtered
    if missing_categories:
        print(f"\nWarning: the following harm_category names exist in the original dataset but were not found among the filtered unsafe samples: {missing_categories}")
    else:
        print("\nGood news! Every harm_category name present in the original dataset is covered by the filtered data.")
    print(f"The final dataset contains {len(filtered_data_list)} samples covering {len(covered_categories_in_filtered)} harm_category names.")

    return filtered_data_list, covered_categories_in_filtered  # return the list and the set

# --- Main program ---
if __name__ == "__main__":
    # Run the filtering
    filtered_list, final_categories = filter_pku_saferlhf_detailed(DATASET_NAME, SPLIT_TO_PROCESS)

    if filtered_list:  # check that the returned list is non-empty
        print(f"\nSaving {len(filtered_list)} filtered records in Hugging Face Dataset format...")
        print(f"Target directory: {OUTPUT_DATASET_DIR}")

        try:
            # 1. Convert the list of dicts into a Dataset object.
            #    Hugging Face infers column types automatically; involved_harm_categories is a list of strings.
            hf_dataset = Dataset.from_list(filtered_list)

            # 2. Save it to disk
            if not os.path.exists(OUTPUT_DATASET_DIR):
                os.makedirs(OUTPUT_DATASET_DIR)
                print(f"Created directory: {OUTPUT_DATASET_DIR}")

            hf_dataset.save_to_disk(OUTPUT_DATASET_DIR)
            print(f"\nDataset successfully saved to: {OUTPUT_DATASET_DIR}")
            print(f"You can load it with:")
            print(f"from datasets import load_from_disk")
            print(f"loaded_dataset = load_from_disk('{OUTPUT_DATASET_DIR}')")
            print(f"\nHarm_category names covered by the final dataset: {final_categories}")

            # Print a few samples for inspection (loaded back from the saved Dataset)
            print("\nSample preview (loaded from the saved Dataset):")
            loaded_dataset = Dataset.load_from_disk(OUTPUT_DATASET_DIR)  # reload to verify
            for i in range(min(5, len(loaded_dataset))):
                sample = loaded_dataset[i]
                print(f"--- Sample {i+1} ---")
                print(f"Prompt: {sample['prompt'][:150]}...")
                print(f"Response 0 (is_safe={sample['is_safe_0']}): {sample['response_0'][:100]}...")
                print(f"Response 1 (is_safe={sample['is_safe_1']}): {sample['response_1'][:100]}...")
                print(f"Involved Harm Categories: {sample['involved_harm_categories']}")

        except Exception as e:
            print(f"\nError: failed to save the Hugging Face Dataset: {e}")
            traceback.print_exc()  # print the full traceback

    else:
        print("\nNo samples matched the criteria, or a fatal error occurred while loading/processing the data. Nothing was saved.")
|
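Before running the TTS pass in pkusafe_tts.py (the next file), it can be useful to sanity-check the category coverage of the filtered dataset that pkusafe.py just saved. A minimal sketch, assuming the default OUTPUT_DATASET_DIR above:

# Sketch: quick coverage check on the dataset saved by pkusafe.py (directory name assumed from the default config).
from collections import Counter
from datasets import load_from_disk

ds = load_from_disk("pku_saferlhf_filtered_unsafe_diverse_hf")
# involved_harm_categories is a list-of-strings column; count how often each category appears.
counts = Counter(cat for cats in ds["involved_harm_categories"] for cat in cats)
for category, n in counts.most_common():
    print(f"{category}: {n}")
print(f"Total filtered samples: {len(ds)}")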
r1-a/dataset/pkusafe_tts.py
ADDED
|
@@ -0,0 +1,279 @@
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
# Import load_from_disk to load the dataset saved by your first script
|
| 6 |
+
from datasets import load_dataset, Dataset, load_from_disk
|
| 7 |
+
import sys
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import time # Import time for potential delays between retries
|
| 10 |
+
|
| 11 |
+
sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct for your environment
|
| 12 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 13 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 14 |
+
|
| 15 |
+
# ------------------------
|
| 16 |
+
# Configuration
|
| 17 |
+
# ------------------------
|
| 18 |
+
COMMON_VOICE_LANGUAGE = "en"
|
| 19 |
+
# --- Path to the pre-filtered dataset saved by your FIRST script ---
|
| 20 |
+
FILTERED_DATASET_PATH = "pku_saferlhf_filtered_unsafe_diverse_hf" # <-- IMPORTANT: Make sure this matches the OUTPUT_DATASET_DIR from your first script
|
| 21 |
+
# --- Output path for THIS TTS script ---
|
| 22 |
+
OUTPUT_DATASET_PATH = './pku_saferlhf_filtered_with_audio' # <-- New output path for the dataset with audio
|
| 23 |
+
SAMPLE_RATE = 16000
|
| 24 |
+
MAX_TTS_RETRIES = 3 # Maximum number of TTS attempts per query
|
| 25 |
+
RETRY_DELAY_SECONDS = 2 # Optional delay between retries
|
| 26 |
+
|
| 27 |
+
# ------------------------
|
| 28 |
+
# Helper functions (identical to the previous version with retry logic)
|
| 29 |
+
# ------------------------
|
| 30 |
+
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
|
| 31 |
+
"""
|
| 32 |
+
Randomly draw one utterance and its transcript from the VoxPopuli dataset (used in place of the original Common Voice) to use as a prompt.
|
| 33 |
+
"""
|
| 34 |
+
idx = random.randint(0, len(common_voice_dataset) - 1)
|
| 35 |
+
sample = common_voice_dataset.select([idx])[0]
|
| 36 |
+
audio = sample['audio']
|
| 37 |
+
waveform = torch.tensor(audio['array'], dtype=torch.float32)
|
| 38 |
+
sr = audio['sampling_rate']
|
| 39 |
+
if sr != sample_rate:
|
| 40 |
+
if waveform.dim() > 1:
|
| 41 |
+
waveform = waveform.mean(dim=0)
|
| 42 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
|
| 43 |
+
waveform = resampler(waveform)
|
| 44 |
+
if waveform.dim() == 1:
|
| 45 |
+
waveform = waveform.unsqueeze(0)
|
| 46 |
+
if waveform.numel() == 0 or not sample['raw_text']:
|
| 47 |
+
print("Warning: Got an empty prompt, trying again...")
|
| 48 |
+
return get_random_prompt(common_voice_dataset, sample_rate)
|
| 49 |
+
return waveform, sample['raw_text']
|
| 50 |
+
|
| 51 |
+
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
|
| 52 |
+
"""
|
| 53 |
+
Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
|
| 54 |
+
Includes retry logic on failure.
|
| 55 |
+
"""
|
| 56 |
+
last_exception = None
|
| 57 |
+
for attempt in range(max_retries):
|
| 58 |
+
try:
|
| 59 |
+
prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
|
| 60 |
+
|
| 61 |
+
all_speech = []
|
| 62 |
+
inference_generator = cosyvoice.inference_zero_shot(
|
| 63 |
+
query_text,
|
| 64 |
+
prompt_text,
|
| 65 |
+
prompt_speech,
|
| 66 |
+
stream=stream,
|
| 67 |
+
text_frontend=False
|
| 68 |
+
)
|
| 69 |
+
for i, chunk in enumerate(inference_generator):
|
| 70 |
+
if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
|
| 71 |
+
all_speech.append(chunk['tts_speech'])
|
| 72 |
+
else:
|
| 73 |
+
print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
|
| 74 |
+
|
| 75 |
+
if not all_speech:
|
| 76 |
+
raise ValueError("TTS inference finished but produced no audio chunks.")
|
| 77 |
+
|
| 78 |
+
combined_speech = torch.cat(all_speech, dim=-1)
|
| 79 |
+
sample_rate_val = cosyvoice.sample_rate
|
| 80 |
+
|
| 81 |
+
return {
|
| 82 |
+
'audio_tensor': combined_speech,
|
| 83 |
+
'sample_rate': sample_rate_val
|
| 84 |
+
}
|
| 85 |
+
except Exception as e:
|
| 86 |
+
last_exception = e
|
| 87 |
+
print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
|
| 88 |
+
print(f"Text: '{query_text[:100]}...'")
|
| 89 |
+
print(f"Prompt Text: '{prompt_text[:100]}...'")
|
| 90 |
+
if attempt < max_retries - 1:
|
| 91 |
+
print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
|
| 92 |
+
time.sleep(RETRY_DELAY_SECONDS)
|
| 93 |
+
else:
|
| 94 |
+
print(f"All {max_retries} TTS attempts failed.")
|
| 95 |
+
|
| 96 |
+
print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
|
| 97 |
+
print(f"Last error: {last_exception}")
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
|
| 101 |
+
"""
|
| 102 |
+
Run TTS for a single sample from the filtered PKU-SafeRLHF dataset loaded from disk.
|
| 103 |
+
Processes example['prompt']. <--- Changed from 'query'/'question'
|
| 104 |
+
"""
|
| 105 |
+
# --- Target the 'prompt' field from the filtered PKU-SafeRLHF dataset ---
|
| 106 |
+
query = example.get('prompt') # <--- Use 'prompt' field
|
| 107 |
+
if not query or not isinstance(query, str) or query.strip() == "":
|
| 108 |
+
print(f"Warning: Skipping example due to missing or empty 'prompt' field: {example.keys()}") # Log keys if prompt is missing
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
# --- Use the text_to_audio function with retry logic ---
|
| 112 |
+
audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
|
| 113 |
+
|
| 114 |
+
if audio_result is not None:
|
| 115 |
+
audio_tensor = audio_result['audio_tensor']
|
| 116 |
+
if audio_tensor.dim() == 1:
|
| 117 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
| 118 |
+
elif audio_tensor.dim() > 2:
|
| 119 |
+
print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
|
| 120 |
+
audio_tensor = audio_tensor.view(1, -1)
|
| 121 |
+
|
| 122 |
+
if audio_tensor.numel() == 0:
|
| 123 |
+
print(f"Warning: Generated audio tensor is empty for prompt: '{query[:60]}...'")
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
return {
|
| 127 |
+
'audio_tensor': audio_tensor,
|
| 128 |
+
'sample_rate': audio_result['sample_rate']
|
| 129 |
+
}
|
| 130 |
+
else:
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
# ------------------------
|
| 134 |
+
# Data loading and model initialization
|
| 135 |
+
# ------------------------
|
| 136 |
+
print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
|
| 137 |
+
try:
|
| 138 |
+
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
|
| 139 |
+
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
|
| 140 |
+
if len(common_voice) == 0:
|
| 141 |
+
raise ValueError("VoxPopuli dataset loaded but contains no samples.")
|
| 142 |
+
except Exception as e:
|
| 143 |
+
print(f"Error loading VoxPopuli dataset: {e}")
|
| 144 |
+
sys.exit(1)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
print("Initializing CosyVoice2 model...")
|
| 148 |
+
try:
|
| 149 |
+
cosyvoice = CosyVoice2(
|
| 150 |
+
'/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # Verify this path is correct
|
| 151 |
+
load_jit=True,
|
| 152 |
+
load_trt=False,
|
| 153 |
+
fp16=False
|
| 154 |
+
)
|
| 155 |
+
except Exception as e:
|
| 156 |
+
print(f"Error initializing CosyVoice2 model: {e}")
|
| 157 |
+
sys.exit(1)
|
| 158 |
+
|
| 159 |
+
print(f"Loading pre-filtered PKU-SafeRLHF dataset from disk: {FILTERED_DATASET_PATH}")
|
| 160 |
+
try:
|
| 161 |
+
# --- Load the dataset saved by your first script ---
|
| 162 |
+
filtered_dataset = load_from_disk(FILTERED_DATASET_PATH)
|
| 163 |
+
if not filtered_dataset:
|
| 164 |
+
raise ValueError(f"Dataset loaded from '{FILTERED_DATASET_PATH}' is empty or invalid.")
|
| 165 |
+
print(f"Successfully loaded dataset with {len(filtered_dataset)} examples.")
|
| 166 |
+
# --- Assume the loaded dataset corresponds to a single split (e.g., 'train') ---
|
| 167 |
+
# Wrap it in a dictionary to match the structure expected by the loop below
|
| 168 |
+
dataset_dict = {"train": filtered_dataset} # Use "train" as the key, matching the split processed by the filter script
|
| 169 |
+
# Alternatively, if you know the split name used in the filter script was different, use that name.
|
| 170 |
+
except FileNotFoundError:
|
| 171 |
+
print(f"Error: Pre-filtered dataset not found at '{FILTERED_DATASET_PATH}'.")
|
| 172 |
+
print("Please ensure the first script ran successfully and saved the data to the correct location.")
|
| 173 |
+
sys.exit(1)
|
| 174 |
+
except Exception as e:
|
| 175 |
+
print(f"Error loading pre-filtered dataset from '{FILTERED_DATASET_PATH}': {e}")
|
| 176 |
+
sys.exit(1)
|
| 177 |
+
|
| 178 |
+
# Create the output directory
|
| 179 |
+
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
|
| 180 |
+
|
| 181 |
+
# ------------------------
|
| 182 |
+
# Main processing loop
|
| 183 |
+
# ------------------------
|
| 184 |
+
final_dataset_dict = {}  # Holds the final processed data for each split
|
| 185 |
+
|
| 186 |
+
# Iterate through the splits defined in dataset_dict (should just be 'train' in this case)
|
| 187 |
+
for split_name, split_dataset in dataset_dict.items():
|
| 188 |
+
print(f"Processing split: {split_name} with {len(split_dataset)} examples")
|
| 189 |
+
split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
|
| 190 |
+
os.makedirs(split_output_dir, exist_ok=True)
|
| 191 |
+
|
| 192 |
+
# Progress record used for resuming interrupted runs
|
| 193 |
+
progress_file = os.path.join(split_output_dir, "progress.txt")
|
| 194 |
+
start_index = 0
|
| 195 |
+
if os.path.exists(progress_file):
|
| 196 |
+
try:
|
| 197 |
+
with open(progress_file, "r") as f:
|
| 198 |
+
content = f.read().strip()
|
| 199 |
+
if content:
|
| 200 |
+
start_index = int(content)
|
| 201 |
+
print(f"Resuming split '{split_name}' from sample index {start_index}")
|
| 202 |
+
else:
|
| 203 |
+
print(f"Progress file '{progress_file}' is empty, starting from index 0.")
|
| 204 |
+
start_index = 0
|
| 205 |
+
except ValueError:
|
| 206 |
+
print(f"Could not parse integer from progress file '{progress_file}'. Starting from index 0.")
|
| 207 |
+
start_index = 0
|
| 208 |
+
except Exception as e:
|
| 209 |
+
print(f"Error reading progress file '{progress_file}': {e}. Starting from index 0.")
|
| 210 |
+
start_index = 0
|
| 211 |
+
|
| 212 |
+
final_samples = []
|
| 213 |
+
|
| 214 |
+
# Iterate over and process each sample
|
| 215 |
+
pbar = tqdm(range(start_index, len(split_dataset)), desc=f"Processing {split_name}", initial=start_index, total=len(split_dataset))
|
| 216 |
+
for i in pbar:
|
| 217 |
+
sample = split_dataset[i]
|
| 218 |
+
output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
|
| 219 |
+
|
| 220 |
+
if os.path.exists(output_wav_path):
|
| 221 |
+
sample_dict = {k: sample[k] for k in sample.keys()}
|
| 222 |
+
sample_dict["audio_filepath"] = output_wav_path
|
| 223 |
+
final_samples.append(sample_dict)
|
| 224 |
+
with open(progress_file, "w") as f:
|
| 225 |
+
f.write(str(i + 1))
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
# --- Perform TTS on the 'prompt' field ---
|
| 229 |
+
result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
|
| 230 |
+
|
| 231 |
+
if result is not None:
|
| 232 |
+
audio_tensor = result['audio_tensor']
|
| 233 |
+
sample_rate_val = result['sample_rate']
|
| 234 |
+
|
| 235 |
+
try:
|
| 236 |
+
audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
|
| 237 |
+
if audio_tensor_save.dim() == 1:
|
| 238 |
+
audio_tensor_save = audio_tensor_save.unsqueeze(0)
|
| 239 |
+
elif audio_tensor_save.dim() > 2:
|
| 240 |
+
audio_tensor_save = audio_tensor_save.view(1, -1)
|
| 241 |
+
|
| 242 |
+
torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
|
| 243 |
+
|
| 244 |
+
# Preserve all original fields from the filtered dataset + add audio path
|
| 245 |
+
sample_dict = {k: sample[k] for k in sample.keys()}
|
| 246 |
+
sample_dict["audio_filepath"] = output_wav_path
|
| 247 |
+
final_samples.append(sample_dict)
|
| 248 |
+
|
| 249 |
+
except Exception as e:
|
| 250 |
+
print(f"Failed to save wav for sample {i} at {output_wav_path}: {e}")
|
| 251 |
+
else:
|
| 252 |
+
print(f"Sample {i} processing failed after retries (Prompt: '{sample.get('prompt', 'N/A')[:60]}...'), no audio generated.")
|
| 253 |
+
|
| 254 |
+
# Update progress file after processing each sample
|
| 255 |
+
with open(progress_file, "w") as f:
|
| 256 |
+
f.write(str(i + 1))
|
| 257 |
+
|
| 258 |
+
# Generate Hugging Face Dataset from the collected successful samples and save
|
| 259 |
+
if final_samples:
|
| 260 |
+
final_dataset_obj = Dataset.from_list(final_samples)
|
| 261 |
+
final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
|
| 262 |
+
try:
|
| 263 |
+
print(f"Saving final dataset for split '{split_name}' to {final_dataset_save_path}...")
|
| 264 |
+
final_dataset_obj.save_to_disk(final_dataset_save_path)
|
| 265 |
+
print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with audio paths.")
|
| 266 |
+
final_dataset_dict[split_name] = final_dataset_obj
|
| 267 |
+
except Exception as e:
|
| 268 |
+
print(f"Error saving final dataset for split '{split_name}' to disk: {e}")
|
| 269 |
+
else:
|
| 270 |
+
print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
print("="*30)
|
| 274 |
+
if final_dataset_dict:
|
| 275 |
+
print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.")
|
| 276 |
+
print(f"Processed splits: {list(final_dataset_dict.keys())}")
|
| 277 |
+
else:
|
| 278 |
+
print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.")
|
| 279 |
+
print("="*30)
|
r1-a/dataset/retry_rewrite.py
ADDED
|
@@ -0,0 +1,442 @@
|
| 1 |
+
import os
|
| 2 |
+
import http.client
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import random
|
| 6 |
+
# Import necessary types from datasets
|
| 7 |
+
from datasets import load_dataset, Dataset, DatasetDict, Features, Value, Sequence
|
| 8 |
+
from tqdm.auto import tqdm
|
| 9 |
+
import sys
|
| 10 |
+
import logging
|
| 11 |
+
# Removed other unused imports, but socket itself is still required below
|
| 12 |
+
import socket  # needed: call_llm_api catches socket.gaierror / socket.timeout in its retry handler
|
| 13 |
+
import concurrent.futures
|
| 14 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
+
import shutil # Needed for atomic directory removal
|
| 16 |
+
|
| 17 |
+
# --- Configuration ---
|
| 18 |
+
DATASET_NAME = "virtuoussy/Multi-subject-RLVR" # Or the original source if needed
|
| 19 |
+
DATASET_SPLIT = "train"
|
| 20 |
+
API_HOST = "api2.aigcbest.top"
|
| 21 |
+
API_PATH = "/v1/chat/completions"
|
| 22 |
+
LLM_MODEL = "gpt-4.1-mini"
|
| 23 |
+
API_KEY = os.environ.get('AIGCBEST_API_KEY', "sk-U15cDXxI0bboL6iH4Hymzl30ws6oWzazWe1Ndwq9QtiPUEgI")
|
| 24 |
+
if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
|
| 25 |
+
print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.")
|
| 26 |
+
sys.exit(1)
|
| 27 |
+
|
| 28 |
+
OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
|
| 29 |
+
# Path to the existing, potentially incomplete, processed dataset (LOAD ONLY)
|
| 30 |
+
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
|
| 31 |
+
# Path where intermediate and final results will be saved (SAVE ONLY)
|
| 32 |
+
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")
|
| 33 |
+
|
| 34 |
+
MAX_WORKERS = 20 # Adjust based on your system and API rate limits
|
| 35 |
+
REQUEST_DELAY_SECONDS = 0.15 # Base delay between requests
|
| 36 |
+
MAX_RETRIES = 3 # Max retries for each API call
|
| 37 |
+
SAVE_INTERVAL = 2000 # <<<--- How often to save progress (in number of processed items)
|
| 38 |
+
|
| 39 |
+
# Setup logging
|
| 40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 41 |
+
logging.getLogger("datasets").setLevel(logging.WARNING)
|
| 42 |
+
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
| 43 |
+
logging.getLogger("filelock").setLevel(logging.WARNING) # Quiet down filelock warnings during save
|
| 44 |
+
|
| 45 |
+
# --- LLM API Function (call_llm_api) ---
|
| 46 |
+
# (No changes needed here, keep the robust version)
|
| 47 |
+
def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
|
| 48 |
+
system_prompt = (
|
| 49 |
+
"You are an expert linguist specializing in converting structured prompts or "
|
| 50 |
+
"fill-in-the-blank problems into natural, spoken-language questions suitable for "
|
| 51 |
+
"text-to-speech (TTS). Your goal is to make the question sound like how a person "
|
| 52 |
+
"would naturally ask it. "
|
| 53 |
+
"If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
|
| 54 |
+
"rephrase it as a direct question asking for the missing information. "
|
| 55 |
+
"Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
|
| 56 |
+
"Focus only on rephrasing the *user's question* part provided. "
|
| 57 |
+
"Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
|
| 58 |
+
)
|
| 59 |
+
payload = json.dumps({
|
| 60 |
+
"model": model,
|
| 61 |
+
"messages": [
|
| 62 |
+
{"role": "system", "content": system_prompt},
|
| 63 |
+
{"role": "user", "content": original_question}
|
| 64 |
+
],
|
| 65 |
+
})
|
| 66 |
+
headers = {
|
| 67 |
+
'Accept': 'application/json',
|
| 68 |
+
'Authorization': f'Bearer {api_key}',
|
| 69 |
+
'User-Agent': 'HuggingFace Dataset Processing Script (Retry w/ Save)',
|
| 70 |
+
'Content-Type': 'application/json'
|
| 71 |
+
}
|
| 72 |
+
time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))
|
| 73 |
+
|
| 74 |
+
for attempt in range(retries):
|
| 75 |
+
# logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...")
|
| 76 |
+
try:
|
| 77 |
+
conn = http.client.HTTPSConnection(host, timeout=60) # Increased timeout
|
| 78 |
+
conn.request("POST", path, payload, headers)
|
| 79 |
+
res = conn.getresponse()
|
| 80 |
+
status = res.status
|
| 81 |
+
data = res.read()
|
| 82 |
+
conn.close()
|
| 83 |
+
|
| 84 |
+
if status == 200:
|
| 85 |
+
response_json = json.loads(data.decode("utf-8"))
|
| 86 |
+
if response_json.get("choices") and len(response_json["choices"]) > 0:
|
| 87 |
+
message = response_json["choices"][0].get("message")
|
| 88 |
+
if message and message.get("content"):
|
| 89 |
+
rephrased = message["content"].strip()
|
| 90 |
+
# Remove surrounding quotes more robustly
|
| 91 |
+
if len(rephrased) > 1:
|
| 92 |
+
if (rephrased.startswith('"') and rephrased.endswith('"')) or \
|
| 93 |
+
(rephrased.startswith("'") and rephrased.endswith("'")):
|
| 94 |
+
rephrased = rephrased[1:-1]
|
| 95 |
+
# Handle cases like 'Rephrased: "..."'
|
| 96 |
+
if rephrased.lower().startswith(("rephrased:", "here's the rephrased question:")):
|
| 97 |
+
parts = rephrased.split(":", 1)
|
| 98 |
+
if len(parts) > 1:
|
| 99 |
+
potential_rephrased = parts[1].strip()
|
| 100 |
+
if (potential_rephrased.startswith('"') and potential_rephrased.endswith('"')) or \
|
| 101 |
+
(potential_rephrased.startswith("'") and potential_rephrased.endswith("'")):
|
| 102 |
+
rephrased = potential_rephrased[1:-1]
|
| 103 |
+
else:
|
| 104 |
+
rephrased = potential_rephrased
|
| 105 |
+
|
| 106 |
+
if rephrased and rephrased.strip().lower() != original_question.strip().lower():
|
| 107 |
+
# logging.debug(f"Successfully rephrased: {rephrased[:80]}...")
|
| 108 |
+
return rephrased
|
| 109 |
+
elif not rephrased:
|
| 110 |
+
logging.warning(f"LLM returned empty/whitespace response for: {original_question[:50]}...")
|
| 111 |
+
return None
|
| 112 |
+
else:
|
| 113 |
+
logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
|
| 114 |
+
return None # Treat identical as failure
|
| 115 |
+
logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
|
| 116 |
+
return None
|
| 117 |
+
elif status == 429: # Rate limit
|
| 118 |
+
retry_after_header = res.getheader('Retry-After', '5')
|
| 119 |
+
try: wait_time = int(retry_after_header)
|
| 120 |
+
except ValueError: wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
|
| 121 |
+
logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
|
| 122 |
+
time.sleep(wait_time)
|
| 123 |
+
elif status >= 500: # Server error
|
| 124 |
+
wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
|
| 125 |
+
logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
|
| 126 |
+
time.sleep(wait_time)
|
| 127 |
+
else: # Other client errors (4xx) - Don't retry these
|
| 128 |
+
logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
|
| 132 |
+
logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
|
| 133 |
+
if attempt + 1 == retries: return None
|
| 134 |
+
wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
|
| 135 |
+
logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
|
| 136 |
+
time.sleep(wait_time)
|
| 137 |
+
except json.JSONDecodeError as e:
|
| 138 |
+
logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200] if data else 'N/A'}")
|
| 139 |
+
if attempt + 1 == retries: return None
|
| 140 |
+
wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
|
| 141 |
+
time.sleep(wait_time)
|
| 142 |
+
except Exception as e:
|
| 143 |
+
logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
|
| 144 |
+
if attempt + 1 == retries: return None
|
| 145 |
+
wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
|
| 146 |
+
logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
|
| 147 |
+
time.sleep(wait_time)
|
| 148 |
+
|
| 149 |
+
logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
# --- Dataset Processing Function (rephrase_query_entry) ---
|
| 153 |
+
# (No changes needed here)
|
| 154 |
+
def rephrase_query_entry(example):
|
| 155 |
+
processed_example = example.copy()
|
| 156 |
+
original_query_list = example.get("query")
|
| 157 |
+
processed_example['query_rephrased_status'] = 'processing_retry'
|
| 158 |
+
|
| 159 |
+
if original_query_list is None:
|
| 160 |
+
processed_example['query_rephrased_status'] = 'skipped_missing_query_column'
|
| 161 |
+
processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
|
| 162 |
+
return processed_example
|
| 163 |
+
if not isinstance(original_query_list, list):
|
| 164 |
+
processed_example['query_rephrased_status'] = 'skipped_query_not_list'
|
| 165 |
+
processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
|
| 166 |
+
return processed_example
|
| 167 |
+
if not original_query_list:
|
| 168 |
+
processed_example['query_rephrased_status'] = 'skipped_query_list_empty'
|
| 169 |
+
processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
|
| 170 |
+
return processed_example
|
| 171 |
+
|
| 172 |
+
user_question = None
|
| 173 |
+
for message in original_query_list:
|
| 174 |
+
if isinstance(message, dict) and message.get("role") == "user":
|
| 175 |
+
content = message.get("content")
|
| 176 |
+
if isinstance(content, str) and content.strip():
|
| 177 |
+
user_question = content
|
| 178 |
+
break
|
| 179 |
+
else:
|
| 180 |
+
processed_example['query_rephrased_status'] = 'skipped_invalid_user_content'
|
| 181 |
+
processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
|
| 182 |
+
return processed_example
|
| 183 |
+
if not user_question:
|
| 184 |
+
processed_example['query_rephrased_status'] = 'skipped_no_user_content_found'
|
| 185 |
+
processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
|
| 186 |
+
return processed_example
|
| 187 |
+
|
| 188 |
+
# logging.info(f"Retrying: {user_question[:60]}...")
|
| 189 |
+
rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)
|
| 190 |
+
|
| 191 |
+
if rephrased_query_content:
|
| 192 |
+
processed_example["query_rephrased"] = rephrased_query_content
|
| 193 |
+
processed_example['query_rephrased_status'] = 'success'
|
| 194 |
+
else:
|
| 195 |
+
# Keep the OLD 'query_rephrased' value if LLM call fails this time
|
| 196 |
+
processed_example['query_rephrased'] = example.get('query_rephrased')
|
| 197 |
+
processed_example['query_rephrased_status'] = 'failed_llm_call'
|
| 198 |
+
|
| 199 |
+
return processed_example
|
| 200 |
+
|
| 201 |
+
# --- Function to Save Dataset Atomically ---
|
| 202 |
+
# Saves to a temporary path then renames for safety.
|
| 203 |
+
def save_dataset_atomically(data_list, output_path, features):
|
| 204 |
+
"""Saves the list of data dictionaries atomically using the correct schema."""
|
| 205 |
+
if not data_list:
|
| 206 |
+
logging.info("No data provided for saving.")
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
temp_output_path = output_path + "_saving" # Temporary directory
|
| 210 |
+
final_output_path = output_path
|
| 211 |
+
|
| 212 |
+
logging.info(f"Attempting to save {len(data_list)} examples to temp path {temp_output_path}...")
|
| 213 |
+
try:
|
| 214 |
+
# Create dataset from the list of dictionaries using the defined features
|
| 215 |
+
processed_dataset = Dataset.from_list(list(data_list), features=features) # Convert just in case
|
| 216 |
+
|
| 217 |
+
# Ensure parent directory exists
|
| 218 |
+
os.makedirs(os.path.dirname(final_output_path), exist_ok=True)
|
| 219 |
+
|
| 220 |
+
# Remove any previous temporary directory if it exists
|
| 221 |
+
if os.path.exists(temp_output_path):
|
| 222 |
+
logging.warning(f"Removing existing temporary save directory: {temp_output_path}")
|
| 223 |
+
shutil.rmtree(temp_output_path) # Use shutil for directories
|
| 224 |
+
|
| 225 |
+
# Save the dataset to the temporary path
|
| 226 |
+
processed_dataset.save_to_disk(temp_output_path)
|
| 227 |
+
logging.info(f"Successfully saved dataset to temporary path: {temp_output_path}")
|
| 228 |
+
|
| 229 |
+
# --- Atomic Rename ---
|
| 230 |
+
# Remove the final destination path if it exists
|
| 231 |
+
if os.path.exists(final_output_path):
|
| 232 |
+
logging.debug(f"Removing existing final destination directory before rename: {final_output_path}")
|
| 233 |
+
shutil.rmtree(final_output_path)
|
| 234 |
+
|
| 235 |
+
# Rename the temporary path to the final path
|
| 236 |
+
os.rename(temp_output_path, final_output_path)
|
| 237 |
+
logging.info(f"Successfully moved temporary save to final path: {final_output_path}")
|
| 238 |
+
return True
|
| 239 |
+
|
| 240 |
+
except Exception as e:
|
| 241 |
+
logging.error(f"Failed during atomic save process to {final_output_path}: {e}", exc_info=True)
|
| 242 |
+
# Attempt to clean up temporary directory if it still exists after failure
|
| 243 |
+
if os.path.exists(temp_output_path):
|
| 244 |
+
try:
|
| 245 |
+
shutil.rmtree(temp_output_path)
|
| 246 |
+
logging.info(f"Cleaned up temporary directory {temp_output_path} after error.")
|
| 247 |
+
except Exception as cleanup_e:
|
| 248 |
+
logging.error(f"Could not clean up temporary directory {temp_output_path} after error: {cleanup_e}")
|
| 249 |
+
# Fallback save attempt to JSON Lines (unchanged)
|
| 250 |
+
fallback_json_path = final_output_path + ".jsonl.failed_save" # Indicate it's a fallback
|
| 251 |
+
logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
|
| 252 |
+
try:
|
| 253 |
+
with open(fallback_json_path, 'w', encoding='utf-8') as f:
|
| 254 |
+
for item in data_list:
|
| 255 |
+
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
| 256 |
+
logging.info(f"Successfully saved fallback JSON Lines file.")
|
| 257 |
+
except Exception as json_e:
|
| 258 |
+
logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
|
| 259 |
+
return False
|
| 260 |
+
|
| 261 |
+
# --- Function to Check if Retry is Needed ---
|
| 262 |
+
# (No changes needed here)
|
| 263 |
+
def needs_retry(example):
|
| 264 |
+
rephrased = example.get('query_rephrased')
|
| 265 |
+
status = example.get('query_rephrased_status')
|
| 266 |
+
# Retry if rephrased is missing OR status is anything other than 'success'
|
| 267 |
+
# This ensures failed/skipped items are retried.
|
| 268 |
+
retry_flag = (rephrased is None) or (status != 'success')
|
| 269 |
+
return retry_flag
|
| 270 |
+
|
| 271 |
+
# --- Main Execution ---
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
start_time = time.time()
|
| 274 |
+
logging.info("======================================================")
|
| 275 |
+
logging.info(f" Starting Dataset Processing - RETRY w/ PERIODIC SAVE")
|
| 276 |
+
logging.info(f" Saving progress every {SAVE_INTERVAL} processed items.")
|
| 277 |
+
logging.info("======================================================")
|
| 278 |
+
logging.info(f"Loading existing data from: {PROCESSED_DATA_PATH}")
|
| 279 |
+
logging.info(f"Intermediate and final output will be saved to: {FINAL_OUTPUT_PATH}")
|
| 280 |
+
|
| 281 |
+
# --- Load Existing Processed Dataset ---
|
| 282 |
+
if not os.path.exists(PROCESSED_DATA_PATH):
|
| 283 |
+
logging.error(f"Existing data directory not found at '{PROCESSED_DATA_PATH}'. Cannot run retry mode.")
|
| 284 |
+
sys.exit(1)
|
| 285 |
+
|
| 286 |
+
logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
|
| 287 |
+
try:
|
| 288 |
+
existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
|
| 289 |
+
# Get features *before* converting to list
|
| 290 |
+
dataset_features = existing_dataset.features
|
| 291 |
+
logging.info(f"Dataset features detected: {dataset_features}")
|
| 292 |
+
# Convert to a list of dictionaries for in-memory modification
|
| 293 |
+
results_list = existing_dataset.to_list()
|
| 294 |
+
total_examples = len(results_list)
|
| 295 |
+
logging.info(f"Loaded {total_examples} examples.")
|
| 296 |
+
except Exception as e:
|
| 297 |
+
logging.error(f"Failed to load dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
|
| 298 |
+
# Check if the final path exists from a previous run - maybe load that?
|
| 299 |
+
# For now, exiting is safer to avoid inconsistent states.
|
| 300 |
+
# if os.path.exists(FINAL_OUTPUT_PATH):
|
| 301 |
+
# logging.warning(f"Consider manually checking/using the existing data at {FINAL_OUTPUT_PATH}")
|
| 302 |
+
sys.exit(1)
|
| 303 |
+
|
| 304 |
+
# --- Identify Indices to Retry ---
|
| 305 |
+
logging.info("Identifying examples needing retry...")
|
| 306 |
+
indices_to_retry = [
|
| 307 |
+
i for i, example in enumerate(tqdm(results_list, desc="Checking examples")) if needs_retry(example)
|
| 308 |
+
]
|
| 309 |
+
num_to_retry = len(indices_to_retry)
|
| 310 |
+
|
| 311 |
+
if num_to_retry == 0:
|
| 312 |
+
logging.info("No examples found needing retry based on the criteria ('query_rephrased' is None or status != 'success').")
|
| 313 |
+
logging.info(f"Saving the existing dataset to the final location '{FINAL_OUTPUT_PATH}' as is...")
|
| 314 |
+
if not save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features): # Use atomic save
|
| 315 |
+
logging.error("Failed to save the dataset to the final location even though no retries were needed.")
|
| 316 |
+
sys.exit(0)
|
| 317 |
+
|
| 318 |
+
logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")
|
| 319 |
+
|
| 320 |
+
# --- Prepare for Concurrent Retries ---
|
| 321 |
+
processed_count_total = 0 # Total processed in this run
|
| 322 |
+
processed_since_last_save = 0 # Counter for periodic saving
|
| 323 |
+
last_save_time = time.time() # Track time for saving message
|
| 324 |
+
|
| 325 |
+
logging.info("Starting concurrent retries with periodic saving...")
|
| 326 |
+
|
| 327 |
+
# --- ThreadPoolExecutor for Concurrency ---
|
| 328 |
+
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 329 |
+
# Submit tasks only for the identified indices
|
| 330 |
+
futures = {
|
| 331 |
+
executor.submit(rephrase_query_entry, results_list[i]): i
|
| 332 |
+
for i in indices_to_retry
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
try:
|
| 336 |
+
# Initialize progress bar for retries
|
| 337 |
+
pbar = tqdm(total=num_to_retry, desc="Retrying examples", unit="example")
|
| 338 |
+
# Process futures as they complete
|
| 339 |
+
for future in concurrent.futures.as_completed(futures):
|
| 340 |
+
original_index = futures[future] # Get the original list index
|
| 341 |
+
try:
|
| 342 |
+
# Get the result (the updated dictionary)
|
| 343 |
+
updated_example_dict = future.result()
|
| 344 |
+
# --- IMMEDIATE UPDATE of the main list ---
|
| 345 |
+
results_list[original_index] = updated_example_dict
|
| 346 |
+
pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
|
| 347 |
+
|
| 348 |
+
except Exception as exc:
|
| 349 |
+
# Catch potential exceptions *from* the rephrase_query_entry function
|
| 350 |
+
logging.error(f'Retry task for index {original_index} encountered an exception: {exc}', exc_info=True)
|
| 351 |
+
# Create an error placeholder and update the main list
|
| 352 |
+
error_placeholder = results_list[original_index].copy() # Start with original data
|
| 353 |
+
error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
|
| 354 |
+
# Keep the old query_rephrased value
|
| 355 |
+
results_list[original_index] = error_placeholder
|
| 356 |
+
pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
|
| 357 |
+
|
| 358 |
+
finally:
|
| 359 |
+
# Increment counters and update progress bar
|
| 360 |
+
processed_count_total += 1
|
| 361 |
+
processed_since_last_save += 1
|
| 362 |
+
pbar.update(1)
|
| 363 |
+
|
| 364 |
+
# --- Periodic Save Check ---
|
| 365 |
+
if processed_since_last_save >= SAVE_INTERVAL:
|
| 366 |
+
current_time = time.time()
|
| 367 |
+
time_since_last = current_time - last_save_time
|
| 368 |
+
logging.info(f"\n--- Processed {processed_since_last_save} items (Total: {processed_count_total}/{num_to_retry}). Time since last save: {time_since_last:.1f}s. Saving progress... ---")
|
| 369 |
+
if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
|
| 370 |
+
logging.info(f"--- Progress successfully saved to {FINAL_OUTPUT_PATH} ---")
|
| 371 |
+
processed_since_last_save = 0 # Reset counter
|
| 372 |
+
last_save_time = current_time
|
| 373 |
+
else:
|
| 374 |
+
logging.error(f"--- FAILED TO SAVE PROGRESS! Check errors above. Will retry saving later. ---")
|
| 375 |
+
# Don't reset the counter, maybe the next save will work
|
| 376 |
+
|
| 377 |
+
except KeyboardInterrupt:
|
| 378 |
+
logging.warning("\nCtrl+C detected! Attempting final save...")
|
| 379 |
+
# Let the finally block handle the save
|
| 380 |
+
|
| 381 |
+
except Exception as e:
|
| 382 |
+
logging.error(f"An unexpected error occurred during the main retry loop: {e}", exc_info=True)
|
| 383 |
+
logging.error("Attempting final save...")
|
| 384 |
+
# Let the finally block handle the save
|
| 385 |
+
|
| 386 |
+
finally:
|
| 387 |
+
# --- This block executes after the loop finishes, OR if an exception/interrupt occurs ---
|
| 388 |
+
if 'pbar' in locals() and pbar is not None:
|
| 389 |
+
pbar.close()
|
| 390 |
+
|
| 391 |
+
logging.info("--- Processing loop finished or interrupted. ---")
|
| 392 |
+
|
| 393 |
+
# --- Final Save Attempt ---
|
| 394 |
+
# No need to update results_list again, it was updated incrementally.
|
| 395 |
+
logging.info(f"Attempting final save of the dataset ({len(results_list)} items) to: {FINAL_OUTPUT_PATH}")
|
| 396 |
+
if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
|
| 397 |
+
logging.info("--- Final dataset state saved successfully. ---")
|
| 398 |
+
else:
|
| 399 |
+
logging.error(">>> FINAL SAVE FAILED! <<< Check logs. Fallback JSON file might exist.")
|
| 400 |
+
|
| 401 |
+
# --- Final Verification (Optional but Recommended) ---
|
| 402 |
+
logging.info("------------------------------------------------------")
|
| 403 |
+
logging.info("Verification: Attempting to load final saved dataset...")
|
| 404 |
+
try:
|
| 405 |
+
final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
|
| 406 |
+
logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
|
| 407 |
+
|
| 408 |
+
# Simple status count
|
| 409 |
+
status_counts = {}
|
| 410 |
+
none_rephrased_count = 0
|
| 411 |
+
for ex in final_reloaded_dataset:
|
| 412 |
+
status = ex.get('query_rephrased_status', 'unknown_status')
|
| 413 |
+
status_counts[status] = status_counts.get(status, 0) + 1
|
| 414 |
+
if ex.get('query_rephrased') is None or not str(ex.get('query_rephrased')).strip():
|
| 415 |
+
none_rephrased_count += 1
|
| 416 |
+
|
| 417 |
+
logging.info("Final status counts:")
|
| 418 |
+
for status, count in sorted(status_counts.items()):
|
| 419 |
+
logging.info(f" - {status}: {count}")
|
| 420 |
+
|
| 421 |
+
final_success = status_counts.get('success', 0)
|
| 422 |
+
final_failed = sum(count for st, count in status_counts.items() if st and (st.startswith('failed_') or st == 'processing_retry')) # Items potentially stuck
|
| 423 |
+
final_skipped = sum(count for st, count in status_counts.items() if st and st.startswith('skipped_'))
|
| 424 |
+
other_count = len(final_reloaded_dataset) - final_success - final_failed - final_skipped
|
| 425 |
+
|
| 426 |
+
logging.info(f"Summary: Success={final_success}, Failed/Incomplete={final_failed}, Skipped={final_skipped}, Other={other_count}")
|
| 427 |
+
if none_rephrased_count > 0:
|
| 428 |
+
logging.warning(f"WARNING: {none_rephrased_count} items have None/empty 'query_rephrased' in the final dataset.")
|
| 429 |
+
if final_failed > 0:
|
| 430 |
+
logging.warning(f"WARNING: {final_failed} items did not reach 'success' or 'skipped' status.")
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
except FileNotFoundError:
|
| 434 |
+
logging.error(f"Verification failed: Final dataset directory not found at {FINAL_OUTPUT_PATH}. Final save likely failed.")
|
| 435 |
+
except Exception as e:
|
| 436 |
+
logging.error(f"Verification failed: Could not reload/verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)
|
| 437 |
+
|
| 438 |
+
# --- Script End ---
|
| 439 |
+
end_time = time.time()
|
| 440 |
+
logging.info("------------------------------------------------------")
|
| 441 |
+
logging.info(f"Script finished in {end_time - start_time:.2f} seconds.")
|
| 442 |
+
logging.info("======================================================")
|
r1-a/dataset/retts.py
ADDED
|
@@ -0,0 +1,559 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import torch
|
| 5 |
+
import re
|
| 6 |
+
import jiwer
|
| 7 |
+
from datasets import load_from_disk, concatenate_datasets, Dataset, Features, Value, Audio # Keep Audio for potential output type hint if needed
|
| 8 |
+
from transformers import pipeline
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
import soundfile as sf # For checking validity
|
| 12 |
+
import librosa # For loading audio in batch function
|
| 13 |
+
import numpy as np
|
| 14 |
+
import collections
|
| 15 |
+
import pyarrow as pa
|
| 16 |
+
|
| 17 |
+
# --- Logging configuration ---
|
| 18 |
+
log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - [Shard %(shard_index)s] - %(message)s')
|
| 19 |
+
logger = logging.getLogger()
|
| 20 |
+
logger.setLevel(logging.INFO)
|
| 21 |
+
ch = logging.StreamHandler()
|
| 22 |
+
ch.setLevel(logging.INFO)
|
| 23 |
+
ch.setFormatter(log_formatter)
|
| 24 |
+
logger.addHandler(ch)
|
| 25 |
+
fh = None # File handler setup in main
|
| 26 |
+
|
| 27 |
+
# --- Constants and parameter definitions ---
|
| 28 |
+
MODEL_ID = "openai/whisper-large-v3"
|
| 29 |
+
DATASET_PATH = "/home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative" # ADJUST IF NEEDED
|
| 30 |
+
OUTPUT_DIR = "/home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative_processed_shards" # ADJUST IF NEEDED
|
| 31 |
+
LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
|
| 32 |
+
NUM_SHARDS = 50
|
| 33 |
+
MIN_AUDIO_DURATION_MS = 100
|
| 34 |
+
TARGET_SR = 16000 # Whisper expected sample rate
|
| 35 |
+
|
| 36 |
+
# --- Text normalization function ---
|
| 37 |
+
def normalize_text(text):
|
| 38 |
+
if text is None:
|
| 39 |
+
return ""
|
| 40 |
+
text = str(text).lower()
|
| 41 |
+
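# Keep word characters, whitespace, CJK ideographs (U+4E00-U+9FFF) and common Chinese punctuation; drop everything else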
text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、]', '', text)
|
| 42 |
+
text = re.sub(r'\s+', ' ', text).strip()
|
| 43 |
+
return text
|
| 44 |
+
|
| 45 |
+
# --- Audio file pre-check function (checks paths) ---
|
| 46 |
+
# (This function remains largely the same as it already worked with paths)
|
| 47 |
+
def check_audio_file_validity(example, shard_idx_for_log=None):
|
| 48 |
+
is_valid = False
|
| 49 |
+
error_msg = "Unknown error"
|
| 50 |
+
duration_ms = 0
|
| 51 |
+
audio_path = example.get("question_audio") # Directly get the path string
|
| 52 |
+
log_prefix = f"[Shard {shard_idx_for_log}] " if shard_idx_for_log is not None else ""
|
| 53 |
+
|
| 54 |
+
if audio_path and isinstance(audio_path, str):
|
| 55 |
+
if os.path.exists(audio_path):
|
| 56 |
+
try:
|
| 57 |
+
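# sf.info only reads the audio metadata, so this validity check stays cheap even for long recordings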
info = sf.info(audio_path)
|
| 58 |
+
duration_ms = int(info.duration * 1000)
|
| 59 |
+
if info.samplerate > 0 and info.frames > 0:
|
| 60 |
+
if duration_ms >= MIN_AUDIO_DURATION_MS:
|
| 61 |
+
is_valid = True
|
| 62 |
+
error_msg = None
|
| 63 |
+
else:
|
| 64 |
+
error_msg = f"Audio duration {duration_ms}ms < minimum {MIN_AUDIO_DURATION_MS}ms"
|
| 65 |
+
else:
|
| 66 |
+
error_msg = "Invalid audio properties (samplerate/frames <= 0)"
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.warning(f"{log_prefix}Cannot read info/validate file {audio_path}: {type(e).__name__}")
|
| 69 |
+
error_msg = f"Cannot read/validate audio: {type(e).__name__}"
|
| 70 |
+
else:
|
| 71 |
+
error_msg = "Audio file not found"
|
| 72 |
+
elif audio_path is None:
|
| 73 |
+
error_msg = "Audio path is missing or null"
|
| 74 |
+
else:
|
| 75 |
+
error_msg = f"Audio path is not a string (type: {type(audio_path).__name__})"
|
| 76 |
+
|
| 77 |
+
return {
|
| 78 |
+
"audio_is_valid": is_valid,
|
| 79 |
+
"audio_check_error": error_msg,
|
| 80 |
+
"audio_duration_ms": duration_ms
|
| 81 |
+
# Don't add the original path back here, it's already in the dataset
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# --- Core processing function (batched - loads audio from paths) ---
|
| 86 |
+
def check_audio_quality_batch(batch, asr_pipeline, wer_threshold, target_sr, shard_idx_for_log=None):
|
| 87 |
+
log_prefix = f"[Shard {shard_idx_for_log}] " if shard_idx_for_log is not None else ""
|
| 88 |
+
results = {"asr_transcription": [], "wer": [], "is_bad_tts": [], "error_message": []}
|
| 89 |
+
original_texts = batch.get("question_text", [])
|
| 90 |
+
audio_paths = batch.get("question_audio", []) # Get list of paths
|
| 91 |
+
|
| 92 |
+
num_samples_in_batch = len(audio_paths)
|
| 93 |
+
if not audio_paths or not original_texts or len(audio_paths) != len(original_texts):
|
| 94 |
+
logger.warning(f"{log_prefix}Batch inconsistency or empty data. Paths: {len(audio_paths)}, Text: {len(original_texts)}")
|
| 95 |
+
num_samples = max(len(audio_paths), len(original_texts))
|
| 96 |
+
results["asr_transcription"] = [""] * num_samples
|
| 97 |
+
results["wer"] = [1.0] * num_samples
|
| 98 |
+
results["is_bad_tts"] = [True] * num_samples
|
| 99 |
+
results["error_message"] = ["Inconsistent batch data or missing paths/text"] * num_samples
|
| 100 |
+
return results
|
| 101 |
+
|
| 102 |
+
batch_load_start_time = time.time()
|
| 103 |
+
loaded_audios = []
|
| 104 |
+
load_errors = [None] * num_samples_in_batch # Track loading errors per sample
|
| 105 |
+
|
| 106 |
+
# --- Load every audio file in the batch ---
|
| 107 |
+
for i, path in enumerate(audio_paths):
|
| 108 |
+
try:
|
| 109 |
+
if not path or not isinstance(path, str):
|
| 110 |
+
raise ValueError("Invalid audio path")
|
| 111 |
+
# Load using librosa, force mono, resample to target_sr
|
| 112 |
+
audio_array, sample_rate = librosa.load(path, sr=target_sr, mono=True)
|
| 113 |
+
loaded_audios.append(audio_array)
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logger.warning(f"{log_prefix}Failed to load audio file '{path}': {type(e).__name__}. Skipping for ASR.")
|
| 116 |
+
loaded_audios.append(None) # Use None as placeholder for failed loads
|
| 117 |
+
load_errors[i] = f"Audio load failed: {type(e).__name__}"
|
| 118 |
+
|
| 119 |
+
batch_load_end_time = time.time()
|
| 120 |
+
logger.debug(f"{log_prefix}Loaded {len([a for a in loaded_audios if a is not None])}/{num_samples_in_batch} audios in {batch_load_end_time - batch_load_start_time:.2f} sec.")
|
| 121 |
+
|
| 122 |
+
# Filter out None placeholders before sending to ASR pipeline?
|
| 123 |
+
# Option 1: Send only valid audios (might complicate matching results back)
|
| 124 |
+
# Option 2: Send list including None/empty arrays, let pipeline handle (or pre-handle)
|
| 125 |
+
# Let's try Option 2 with pre-handling: Replace None with empty array for pipeline input
|
| 126 |
+
pipeline_inputs = []
|
| 127 |
+
valid_indices = [] # Track indices of samples sent to pipeline
|
| 128 |
+
for i, audio_data in enumerate(loaded_audios):
|
| 129 |
+
if audio_data is not None and len(audio_data) > 0: # Check if loading succeeded and audio not empty
|
| 130 |
+
pipeline_inputs.append(audio_data)
|
| 131 |
+
valid_indices.append(i)
|
| 132 |
+
# else: keep load_errors[i] message
|
| 133 |
+
|
| 134 |
+
asr_results_list = [None] * num_samples_in_batch # Initialize results list matching original batch size
|
| 135 |
+
|
| 136 |
+
# --- ASR inference (only on successfully loaded audio) ---
|
| 137 |
+
if pipeline_inputs: # Only run pipeline if there are valid audios
|
| 138 |
+
batch_asr_start_time = time.time()
|
| 139 |
+
try:
|
| 140 |
+
# Pass the list of NumPy arrays directly to the pipeline
|
| 141 |
+
asr_outputs = asr_pipeline(pipeline_inputs, generate_kwargs={"language": "zh", "task": "transcribe"})
|
| 142 |
+
|
| 143 |
+
if not isinstance(asr_outputs, list):
|
| 144 |
+
asr_outputs = [asr_outputs] # Ensure it's a list
|
| 145 |
+
|
| 146 |
+
# Map results back to original batch positions using valid_indices
|
| 147 |
+
if len(asr_outputs) == len(valid_indices):
|
| 148 |
+
for idx, result in zip(valid_indices, asr_outputs):
|
| 149 |
+
asr_results_list[idx] = result # Place result at the correct original index
|
| 150 |
+
else:
|
| 151 |
+
logger.error(f"{log_prefix}ASR output count ({len(asr_outputs)}) mismatch with valid input count ({len(valid_indices)}). Marking all in batch as error.")
|
| 152 |
+
# Mark all samples in the batch with an error if counts mismatch
|
| 153 |
+
for i in range(num_samples_in_batch):
|
| 154 |
+
if load_errors[i] is None: # If loading didn't fail, mark as ASR mismatch
|
| 155 |
+
load_errors[i] = "ASR count mismatch error"
|
| 156 |
+
|
| 157 |
+
# --- Error Handling for ASR Pipeline ---
|
| 158 |
+
# Catch errors specifically from the pipeline call
|
| 159 |
+
except ValueError as ve: # e.g., internal batching errors if any remain
|
| 160 |
+
logger.error(f"{log_prefix}ValueError during ASR pipeline processing: {ve}", exc_info=True)
|
| 161 |
+
for idx in valid_indices: # Mark only those sent to pipeline as failed
|
| 162 |
+
asr_results_list[idx] = "ERROR: ASR ValueError" # Placeholder or error indicator
|
| 163 |
+
if load_errors[idx] is None: load_errors[idx] = f"ASR ValueError: {str(ve)[:100]}"
|
| 164 |
+
except torch.cuda.OutOfMemoryError:
|
| 165 |
+
logger.error(f"{log_prefix}CUDA OutOfMemoryError during ASR batch processing.")
|
| 166 |
+
torch.cuda.empty_cache()
|
| 167 |
+
for idx in valid_indices:
|
| 168 |
+
asr_results_list[idx] = "ERROR: ASR OOM"
|
| 169 |
+
if load_errors[idx] is None: load_errors[idx] = "ASR CUDA OOM"
|
| 170 |
+
except Exception as e:
|
| 171 |
+
logger.error(f"{log_prefix}Exception during ASR pipeline processing: {e}", exc_info=True)
|
| 172 |
+
for idx in valid_indices:
|
| 173 |
+
asr_results_list[idx] = "ERROR: ASR Exception"
|
| 174 |
+
if load_errors[idx] is None: load_errors[idx] = f"ASR Exception: {str(e)[:100]}"
|
| 175 |
+
|
| 176 |
+
batch_asr_end_time = time.time()
|
| 177 |
+
logger.debug(f"{log_prefix}ASR processed {len(valid_indices)} audios in {batch_asr_end_time - batch_asr_start_time:.2f} sec.")
|
| 178 |
+
|
| 179 |
+
# --- Compute WER (iterate over the original batch size) ---
|
| 180 |
+
for i in range(num_samples_in_batch):
|
| 181 |
+
transcription = ""
|
| 182 |
+
wer = 1.0 # Default to max error
|
| 183 |
+
is_bad = True
|
| 184 |
+
error_msg = load_errors[i] # Start with potential loading error
|
| 185 |
+
|
| 186 |
+
asr_result = asr_results_list[i]
|
| 187 |
+
|
| 188 |
+
if error_msg is None: # If no loading error, proceed with ASR result
|
| 189 |
+
if isinstance(asr_result, dict) and "text" in asr_result:
|
| 190 |
+
transcription = asr_result["text"]
|
| 191 |
+
original_text = original_texts[i]
|
| 192 |
+
|
| 193 |
+
norm_original = normalize_text(original_text)
|
| 194 |
+
norm_transcription = normalize_text(transcription)
|
| 195 |
+
|
| 196 |
+
if not norm_original:
|
| 197 |
+
wer = 1.0 if norm_transcription else 0.0
|
| 198 |
+
is_bad = True if norm_transcription else False
|
| 199 |
+
error_msg = "Original text normalized to empty" if is_bad else "Original text normalized to empty, transcription also empty"
|
| 200 |
+
else:
|
| 201 |
+
try:
|
| 202 |
+
wer = jiwer.wer(norm_original, norm_transcription)
|
| 203 |
+
wer = min(wer, 1.0) # Clamp WER
|
| 204 |
+
is_bad = wer > wer_threshold
|
| 205 |
+
except ValueError as e:
|
| 206 |
+
wer = 1.0
|
| 207 |
+
is_bad = True
|
| 208 |
+
logger.warning(f"{log_prefix}Jiwer WER calculation error for idx {i}. Setting WER to 1.0. Error: {e}")
|
| 209 |
+
error_msg = f"WER calculation error: {e}"
|
| 210 |
+
except Exception as e:
|
| 211 |
+
wer = 1.0
|
| 212 |
+
is_bad = True
|
| 213 |
+
logger.error(f"{log_prefix}Unexpected error during WER calculation idx {i}: {e}", exc_info=True)
|
| 214 |
+
error_msg = f"Unexpected WER error: {e}"
|
| 215 |
+
elif isinstance(asr_result, str) and "ERROR:" in asr_result:
|
| 216 |
+
# Handle error strings passed from ASR exception handling
|
| 217 |
+
error_msg = asr_result
|
| 218 |
+
wer = 1.0
|
| 219 |
+
is_bad = True
|
| 220 |
+
else:
|
| 221 |
+
# ASR didn't run (load failed) or returned unexpected format
|
| 222 |
+
# error_msg should already be set from load_errors
|
| 223 |
+
# If error_msg is somehow still None, set a generic one
|
| 224 |
+
if error_msg is None:
|
| 225 |
+
error_msg = "ASR did not produce valid output"
|
| 226 |
+
wer = 1.0
|
| 227 |
+
is_bad = True
|
| 228 |
+
|
| 229 |
+
results["asr_transcription"].append(transcription)
|
| 230 |
+
results["wer"].append(wer)
|
| 231 |
+
results["is_bad_tts"].append(is_bad)
|
| 232 |
+
results["error_message"].append(error_msg)
|
| 233 |
+
|
| 234 |
+
return results
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# --- Statistics logging function ---
|
| 238 |
+
# (This function remains the same, as it operates on the processed shard data)
|
| 239 |
+
def log_shard_statistics(processed_shard, shard_index, wer_threshold, processing_time):
|
| 240 |
+
log_prefix = f"[Shard {shard_index}] "
|
| 241 |
+
logger.info(f"{log_prefix}--- Shard {shard_index} Statistics ---")
|
| 242 |
+
# ... (rest of the function is identical to the previous version) ...
|
| 243 |
+
total_samples = len(processed_shard)
|
| 244 |
+
logger.info(f"{log_prefix}Total samples processed in this shard: {total_samples}")
|
| 245 |
+
if total_samples == 0:
|
| 246 |
+
logger.info(f"{log_prefix}Shard was empty, no statistics to report.")
|
| 247 |
+
logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
|
| 248 |
+
return
|
| 249 |
+
|
| 250 |
+
logger.info(f"{log_prefix}Processing time for this shard: {processing_time:.2f} seconds")
|
| 251 |
+
if processing_time > 0:
|
| 252 |
+
logger.info(f"{log_prefix}Overall processing speed: {total_samples / processing_time:.2f} samples/sec")
|
| 253 |
+
logger.info(f"{log_prefix}WER threshold used: {wer_threshold}")
|
| 254 |
+
|
| 255 |
+
required_cols = ['is_bad_tts', 'wer', 'error_message', 'question_text', 'asr_transcription']
|
| 256 |
+
if not all(col in processed_shard.column_names for col in required_cols):
|
| 257 |
+
logger.error(f"{log_prefix}Processed shard is missing required columns for statistics ({required_cols}). Skipping detailed stats.")
|
| 258 |
+
logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
|
| 259 |
+
return
|
| 260 |
+
|
| 261 |
+
try:
|
| 262 |
+
bad_tts_count = sum(processed_shard['is_bad_tts'])
|
| 263 |
+
bad_tts_percentage = (bad_tts_count / total_samples) * 100 if total_samples > 0 else 0
|
| 264 |
+
logger.info(f"{log_prefix}Bad TTS samples (WER > {wer_threshold} or Error): {bad_tts_count} ({bad_tts_percentage:.2f}%)")
|
| 265 |
+
logger.info(f"{log_prefix}Good TTS samples (WER <= {wer_threshold}): {total_samples - bad_tts_count} ({100 - bad_tts_percentage:.2f}%)")
|
| 266 |
+
|
| 267 |
+
wer_scores = [w for w in processed_shard['wer'] if w is not None and not np.isnan(w)]
|
| 268 |
+
if wer_scores:
|
| 269 |
+
logger.info(f"{log_prefix}WER Score Distribution (for samples where WER could be calculated):")
|
| 270 |
+
logger.info(f"{log_prefix} Count: {len(wer_scores)}")
|
| 271 |
+
logger.info(f"{log_prefix} Min: {np.min(wer_scores):.4f}")
|
| 272 |
+
logger.info(f"{log_prefix} Max: {np.max(wer_scores):.4f}") # Should be <= 1.0 now
|
| 273 |
+
logger.info(f"{log_prefix} Mean: {np.mean(wer_scores):.4f}")
|
| 274 |
+
logger.info(f"{log_prefix} Median: {np.median(wer_scores):.4f}")
|
| 275 |
+
q25, q75 = np.percentile(wer_scores, [25, 75])
|
| 276 |
+
logger.info(f"{log_prefix} 25th Percentile: {q25:.4f}")
|
| 277 |
+
logger.info(f"{log_prefix} 75th Percentile: {q75:.4f}")
|
| 278 |
+
else:
|
| 279 |
+
logger.info(f"{log_prefix}WER Score Distribution: No valid WER scores found.")
|
| 280 |
+
|
| 281 |
+
error_messages = [msg for msg in processed_shard['error_message'] if msg]
|
| 282 |
+
if error_messages:
|
| 283 |
+
error_counts = collections.Counter(error_messages)
|
| 284 |
+
logger.info(f"{log_prefix}Error Message Summary (Top 10):")
|
| 285 |
+
for msg, count in error_counts.most_common(10):
|
| 286 |
+
logger.info(f"{log_prefix} - \"{msg}\": {count} occurrences")
|
| 287 |
+
if len(error_counts) > 10:
|
| 288 |
+
logger.info(f"{log_prefix} ... ({len(error_counts) - 10} more error types)")
|
| 289 |
+
else:
|
| 290 |
+
logger.info(f"{log_prefix}Error Message Summary: No processing errors recorded.")
|
| 291 |
+
|
| 292 |
+
logger.info(f"\n{log_prefix}--- Example Good TTS Samples (WER <= {wer_threshold}) ---")
|
| 293 |
+
# Use select for potentially large datasets, disable caching for filter
|
| 294 |
+
good_samples_indices = [i for i, bad in enumerate(processed_shard['is_bad_tts']) if not bad]
|
| 295 |
+
num_good_to_show = min(5, len(good_samples_indices))
|
| 296 |
+
if num_good_to_show > 0:
|
| 297 |
+
# Select the samples using indices; this is faster than filter for small selects
|
| 298 |
+
good_samples_view = processed_shard.select(good_samples_indices[:num_good_to_show])
|
| 299 |
+
for i in range(num_good_to_show):
|
| 300 |
+
sample = good_samples_view[i]
|
| 301 |
+
logger.info(f"{log_prefix} Example {i+1}:")
|
| 302 |
+
logger.info(f"{log_prefix} Original Text: {sample['question_text']}")
|
| 303 |
+
logger.info(f"{log_prefix} ASR Transcript: {sample['asr_transcription']}")
|
| 304 |
+
logger.info(f"{log_prefix} WER: {sample['wer']:.4f}")
|
| 305 |
+
logger.info(f"{log_prefix} Audio Path: {sample['question_audio']}") # Show path
|
| 306 |
+
else:
|
| 307 |
+
logger.info(f"{log_prefix} No good samples found in this shard.")
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
logger.info(f"\n{log_prefix}--- Example Bad TTS Samples (WER > {wer_threshold} or Error) ---")
|
| 311 |
+
bad_samples_indices = [i for i, bad in enumerate(processed_shard['is_bad_tts']) if bad]
|
| 312 |
+
num_bad_to_show = min(5, len(bad_samples_indices))
|
| 313 |
+
if num_bad_to_show > 0:
|
| 314 |
+
bad_samples_view = processed_shard.select(bad_samples_indices[:num_bad_to_show])
|
| 315 |
+
for i in range(num_bad_to_show):
|
| 316 |
+
sample = bad_samples_view[i]
|
| 317 |
+
logger.info(f"{log_prefix} Example {i+1}:")
|
| 318 |
+
logger.info(f"{log_prefix} Original Text: {sample['question_text']}")
|
| 319 |
+
logger.info(f"{log_prefix} ASR Transcript: {sample['asr_transcription']}")
|
| 320 |
+
logger.info(f"{log_prefix} WER: {sample['wer']:.4f}")
|
| 321 |
+
logger.info(f"{log_prefix} Error Msg: {sample['error_message']}")
|
| 322 |
+
logger.info(f"{log_prefix} Audio Path: {sample['question_audio']}") # Show path
|
| 323 |
+
else:
|
| 324 |
+
logger.info(f"{log_prefix} No bad samples found in this shard.")
|
| 325 |
+
|
| 326 |
+
except Exception as e:
|
| 327 |
+
logger.error(f"{log_prefix}Error generating statistics: {e}", exc_info=True)
|
| 328 |
+
|
| 329 |
+
logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# --- Main function ---
|
| 333 |
+
def main():
|
| 334 |
+
global fh
|
| 335 |
+
parser = argparse.ArgumentParser(description="Process a shard of the dataset using Whisper ASR, loading audio from paths.")
|
| 336 |
+
# ... (Argument parsing remains the same) ...
|
| 337 |
+
parser.add_argument("--shard_index", type=int, required=True, help=f"Index of the shard to process (0 to {NUM_SHARDS-1}).")
|
| 338 |
+
parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID to use for this process.")
|
| 339 |
+
parser.add_argument("--wer_threshold", type=float, default=0.4, help="WER threshold to mark TTS as bad.")
|
| 340 |
+
parser.add_argument("--pipeline_batch_size", type=int, default=8, help="Internal batch size for the ASR pipeline.")
|
| 341 |
+
parser.add_argument("--map_batch_size", type=int, default=16, help="Batch size for the datasets.map function (how many rows passed to batch func).")
|
| 342 |
+
parser.add_argument("--num_check_workers", type=int, default=4, help="Number of workers for audio pre-check map.")
|
| 343 |
+
args = parser.parse_args()
|
| 344 |
+
shard_index = args.shard_index
|
| 345 |
+
# ... (Rest of argument setup, logging setup, GPU setup - same as before) ...
|
| 346 |
+
gpu_id = args.gpu_id
|
| 347 |
+
wer_threshold = args.wer_threshold
|
| 348 |
+
pipeline_batch_size = args.pipeline_batch_size
|
| 349 |
+
map_batch_size = args.map_batch_size
|
| 350 |
+
num_check_workers = args.num_check_workers
|
| 351 |
+
|
| 352 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
| 353 |
+
log_file = os.path.join(LOG_DIR, f"shard_{shard_index}_gpu_{gpu_id}.log")
|
| 354 |
+
fh = logging.FileHandler(log_file, mode='w')
|
| 355 |
+
fh.setLevel(logging.INFO)
|
| 356 |
+
fh.setFormatter(log_formatter)
|
| 357 |
+
logger.addHandler(fh)
|
| 358 |
+
|
| 359 |
+
old_factory = logging.getLogRecordFactory()
|
| 360 |
+
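# Inject the shard index into every log record so the formatter's %(shard_index)s placeholder resolves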
def record_factory(*args, **kwargs):
|
| 361 |
+
record = old_factory(*args, **kwargs)
|
| 362 |
+
record.shard_index = shard_index
|
| 363 |
+
return record
|
| 364 |
+
logging.setLogRecordFactory(record_factory)
|
| 365 |
+
|
| 366 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
| 367 |
+
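# With CUDA_VISIBLE_DEVICES restricted to a single GPU, that GPU is always logical device cuda:0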
device = f"cuda:0"
|
| 368 |
+
logger.info(f"Process started for Shard {shard_index} on GPU {gpu_id} (logical device {device})")
|
| 369 |
+
logger.info(f"Arguments: {args}")
|
| 370 |
+
|
| 371 |
+
processed_shard = None
|
| 372 |
+
processing_time = 0
|
| 373 |
+
|
| 374 |
+
try:
|
| 375 |
+
# --- Load the full dataset ---
|
| 376 |
+
logger.info(f"Loading dataset from {DATASET_PATH}")
|
| 377 |
+
try:
|
| 378 |
+
full_ds = load_from_disk(DATASET_PATH)
|
| 379 |
+
|
| 380 |
+
logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
|
| 381 |
+
# Check the feature type of question_audio - SHOULD BE string
|
| 382 |
+
if 'question_audio' not in full_ds.features:
|
| 383 |
+
logger.error("Dataset loaded, but required 'question_audio' column is missing!")
|
| 384 |
+
return
|
| 385 |
+
logger.info(f"Feature 'question_audio': {full_ds.features['question_audio']}")
|
| 386 |
+
if not isinstance(full_ds.features['question_audio'], Value) or full_ds.features['question_audio'].dtype != 'string':
|
| 387 |
+
logger.warning(f"'question_audio' column type is not string ({full_ds.features['question_audio']}). Attempting to proceed, but expecting paths.")
|
| 388 |
+
|
| 389 |
+
except Exception as e:
|
| 390 |
+
logger.error(f"Failed to load dataset: {e}", exc_info=True)
|
| 391 |
+
return
|
| 392 |
+
|
| 393 |
+
# --- Preprocessing: check audio file validity (on paths) ---
|
| 394 |
+
logger.info(f"Checking audio file validity (min duration: {MIN_AUDIO_DURATION_MS}ms)...")
|
| 395 |
+
check_features = Features({
|
| 396 |
+
**full_ds.features,
|
| 397 |
+
'audio_is_valid': Value('bool'),
|
| 398 |
+
'audio_check_error': Value('string'),
|
| 399 |
+
'audio_duration_ms': Value('int64')
|
| 400 |
+
})
|
| 401 |
+
num_check_workers = max(1, min(num_check_workers, os.cpu_count()))
|
| 402 |
+
logger.info(f"Using {num_check_workers} workers for audio check.")
|
| 403 |
+
full_ds_checked = full_ds.map(
|
| 404 |
+
check_audio_file_validity,
|
| 405 |
+
num_proc=num_check_workers,
|
| 406 |
+
features=check_features,
|
| 407 |
+
batched=False,
|
| 408 |
+
fn_kwargs={"shard_idx_for_log": shard_index}
|
| 409 |
+
)
|
| 410 |
+
logger.info("Audio validity check complete.")
|
| 411 |
+
|
| 412 |
+
# --- Filter out invalid audio ---
|
| 413 |
+
original_count = len(full_ds_checked)
|
| 414 |
+
valid_audio_ds = full_ds_checked.filter(
|
| 415 |
+
lambda x: x['audio_is_valid'],
|
| 416 |
+
num_proc=num_check_workers,
|
| 417 |
+
load_from_cache_file=False
|
| 418 |
+
)
|
| 419 |
+
filtered_count = original_count - len(valid_audio_ds)
|
| 420 |
+
logger.info(f"Filtered out {filtered_count} samples based on path validity/duration. Kept {len(valid_audio_ds)} samples.")
|
| 421 |
+
|
| 422 |
+
# Log filtering reasons (same as before)
|
| 423 |
+
if filtered_count > 0:
|
| 424 |
+
# Avoid running another potentially slow filter just for logging
|
| 425 |
+
logger.warning("Logging top filtering reasons (based on initial check results, sample limit applies if dataset large)...")
|
| 426 |
+
try:
|
| 427 |
+
error_reasons = collections.Counter(r['audio_check_error'] for r in full_ds_checked.filter(lambda x: not x['audio_is_valid'], load_from_cache_file=False).select(range(min(1000, filtered_count))))
|
| 428 |
+
for reason, count in error_reasons.most_common(10):
|
| 429 |
+
if reason: # Don't log None reasons if any slip through
|
| 430 |
+
logger.warning(f" - {reason}: {count} samples")
|
| 431 |
+
except Exception as log_e:
|
| 432 |
+
logger.warning(f"Could not retrieve filtering reasons: {log_e}")
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
if len(valid_audio_ds) == 0:
|
| 436 |
+
logger.error("No valid audio samples found after filtering. Exiting.")
|
| 437 |
+
return
|
| 438 |
+
|
| 439 |
+
# --- !! REMOVED cast_column step !! ---
|
| 440 |
+
# The 'question_audio' column remains as paths in valid_audio_ds
|
| 441 |
+
|
| 442 |
+
# --- Get the shard this process should handle ---
|
| 443 |
+
logger.info(f"Creating shard {shard_index} from valid audio data (paths)...")
|
| 444 |
+
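# contiguous=True gives each worker an unbroken block of rows, so shard boundaries are deterministic across runs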
ds_shard = valid_audio_ds.shard(num_shards=NUM_SHARDS, index=shard_index, contiguous=True)
|
| 445 |
+
logger.info(f"Shard {shard_index} created with {ds_shard.num_rows} rows.")
|
| 446 |
+
# Log features to confirm 'question_audio' is still string
|
| 447 |
+
logger.info(f"Shard features: {ds_shard.features}")
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
if ds_shard.num_rows == 0:
|
| 451 |
+
logger.warning(f"Shard {shard_index} is empty after sharding. Saving empty structure and exiting process.")
|
| 452 |
+
# Define empty output features (keeping original path column)
|
| 453 |
+
final_features = Features({
|
| 454 |
+
**ds_shard.features, # Includes original columns like question_audio (path)
|
| 455 |
+
'asr_transcription': Value('string'),
|
| 456 |
+
'wer': Value('float32'),
|
| 457 |
+
'is_bad_tts': Value('bool'),
|
| 458 |
+
'error_message': Value('string')
|
| 459 |
+
})
|
| 460 |
+
# Remove check columns from features before creating empty table
|
| 461 |
+
final_features.pop('audio_is_valid', None)
|
| 462 |
+
final_features.pop('audio_check_error', None)
|
| 463 |
+
final_features.pop('audio_duration_ms', None)
|
| 464 |
+
|
| 465 |
+
shard_output_path = os.path.join(OUTPUT_DIR, f"shard_{shard_index}")
|
| 466 |
+
os.makedirs(shard_output_path, exist_ok=True)
|
| 467 |
+
try:
|
| 468 |
+
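# Build an empty Arrow table with the final schema so downstream loading of shards still sees consistent columns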
empty_table = pa.Table.from_pydict({}, schema=final_features.arrow_schema)
|
| 469 |
+
empty_ds = Dataset(arrow_table=empty_table)
|
| 470 |
+
empty_ds.save_to_disk(shard_output_path)
|
| 471 |
+
logger.info(f"Saved empty dataset structure for shard {shard_index}.")
|
| 472 |
+
except Exception as save_e:
|
| 473 |
+
logger.error(f"Could not save empty dataset structure for shard {shard_index}: {save_e}")
|
| 474 |
+
processed_shard = empty_ds # Set processed_shard for stats
|
| 475 |
+
return # Exit after handling empty shard
|
| 476 |
+
|
| 477 |
+
# --- Load the ASR pipeline ---
|
| 478 |
+
logger.info(f"Loading ASR pipeline {MODEL_ID} on {device}...")
|
| 479 |
+
try:
|
| 480 |
+
asr_pipeline = pipeline(
|
| 481 |
+
"automatic-speech-recognition",
|
| 482 |
+
model=MODEL_ID,
|
| 483 |
+
torch_dtype=torch.float16,
|
| 484 |
+
device=device,
|
| 485 |
+
batch_size=pipeline_batch_size # Pipeline's internal batch size
|
| 486 |
+
)
|
| 487 |
+
logger.info(f"ASR pipeline loaded successfully with internal batch size {pipeline_batch_size}.")
|
| 488 |
+
except Exception as e:
|
| 489 |
+
logger.error(f"Failed to load ASR pipeline: {e}", exc_info=True)
|
| 490 |
+
return
|
| 491 |
+
|
| 492 |
+
# --- Process the shard data with map ---
|
| 493 |
+
logger.info(f"Starting processing shard {shard_index} with map batch size {map_batch_size} and WER threshold {wer_threshold}...")
|
| 494 |
+
start_time = time.time()
|
| 495 |
+
# Define output features: Keep original columns + add new ones
|
| 496 |
+
output_features = Features({
|
| 497 |
+
**ds_shard.features, # Keep original columns (incl. question_audio path)
|
| 498 |
+
'asr_transcription': Value('string'),
|
| 499 |
+
'wer': Value('float32'),
|
| 500 |
+
'is_bad_tts': Value('bool'),
|
| 501 |
+
'error_message': Value('string')
|
| 502 |
+
})
|
| 503 |
+
# Remove check columns from output features
|
| 504 |
+
output_features.pop('audio_is_valid', None)
|
| 505 |
+
output_features.pop('audio_check_error', None)
|
| 506 |
+
output_features.pop('audio_duration_ms', None)
|
| 507 |
+
|
| 508 |
+
processed_shard = ds_shard.map(
|
| 509 |
+
check_audio_quality_batch,
|
| 510 |
+
batched=True,
|
| 511 |
+
batch_size=map_batch_size, # map's batch size (rows passed to func)
|
| 512 |
+
fn_kwargs={
|
| 513 |
+
"asr_pipeline": asr_pipeline,
|
| 514 |
+
"wer_threshold": wer_threshold,
|
| 515 |
+
"target_sr": TARGET_SR,
|
| 516 |
+
"shard_idx_for_log": shard_index
|
| 517 |
+
},
|
| 518 |
+
features=output_features, # Define output schema
|
| 519 |
+
load_from_cache_file=False, # Disable caching
|
| 520 |
+
remove_columns=['audio_is_valid', 'audio_check_error', 'audio_duration_ms'] # Remove check columns during map
|
| 521 |
+
)
|
| 522 |
+
end_time = time.time()
|
| 523 |
+
processing_time = end_time - start_time
|
| 524 |
+
logger.info(f"Shard {shard_index} processing finished in {processing_time:.2f} seconds.")
|
| 525 |
+
logger.info(f"Processed shard {shard_index} has columns: {processed_shard.column_names}")
|
| 526 |
+
|
| 527 |
+
# --- Save the processed shard ---
|
| 528 |
+
# No need to remove check columns here, done in map
|
| 529 |
+
shard_output_path = os.path.join(OUTPUT_DIR, f"shard_{shard_index}")
|
| 530 |
+
logger.info(f"Saving processed shard {shard_index} to {shard_output_path}...")
|
| 531 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 532 |
+
try:
|
| 533 |
+
processed_shard.save_to_disk(shard_output_path)
|
| 534 |
+
logger.info(f"Shard {shard_index} saved successfully.")
|
| 535 |
+
except Exception as e:
|
| 536 |
+
# Check specifically for Arrow serialization issues if they occur
|
| 537 |
+
logger.error(f"Failed to save processed shard {shard_index} to {shard_output_path}: {e}", exc_info=True)
|
| 538 |
+
# The IndexError related to soundfile should NOT happen now
|
| 539 |
+
|
| 540 |
+
finally:
|
| 541 |
+
# --- Log statistics ---
|
| 542 |
+
if processed_shard is not None:
|
| 543 |
+
log_shard_statistics(processed_shard, shard_index, wer_threshold, processing_time)
|
| 544 |
+
else:
|
| 545 |
+
logger.warning("Processing did not complete or failed early. No statistics to log.")
|
| 546 |
+
|
| 547 |
+
logger.info(f"Process for Shard {shard_index} on GPU {gpu_id} finished.")
|
| 548 |
+
if fh:
|
| 549 |
+
logger.removeHandler(fh)
|
| 550 |
+
fh.close()
|
| 551 |
+
|
| 552 |
+
if __name__ == "__main__":
|
| 553 |
+
# Note: librosa may need to be added to the project's requirements
|
| 554 |
+
try:
|
| 555 |
+
import librosa
|
| 556 |
+
except ImportError:
|
| 557 |
+
print("Error: librosa is required. Please install it using: pip install librosa")
|
| 558 |
+
exit(1)
|
| 559 |
+
main()
|
r1-a/dataset/sciq.py
ADDED
|
@@ -0,0 +1,176 @@
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from datasets import load_dataset, Dataset
|
| 6 |
+
import sys
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
sys.path.append('/root/autodl-tmp/CosyVoice')
|
| 10 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 11 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 12 |
+
|
| 13 |
+
# ------------------------
|
| 14 |
+
# Configuration parameters
|
| 15 |
+
# ------------------------
|
| 16 |
+
COMMON_VOICE_LANGUAGE = "en"
|
| 17 |
+
DATASET_NAME = "sciq" # 目标数据集:SciQ
|
| 18 |
+
OUTPUT_DATASET_PATH = './sciq_with_audio'
|
| 19 |
+
SAMPLE_RATE = 16000
|
| 20 |
+
|
| 21 |
+
# ------------------------
|
| 22 |
+
# Helper functions
|
| 23 |
+
# ------------------------
|
| 24 |
+
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
|
| 25 |
+
"""
|
| 26 |
+
Randomly sample one utterance and its text from VoxPopuli (used in place of the original Common Voice) to serve as the prompt.
|
| 27 |
+
"""
|
| 28 |
+
idx = random.randint(0, len(common_voice_dataset) - 1)
|
| 29 |
+
sample = common_voice_dataset.select([idx])[0]
|
| 30 |
+
audio = sample['audio']
|
| 31 |
+
waveform = torch.tensor(audio['array'], dtype=torch.float32)
|
| 32 |
+
sr = audio['sampling_rate']
|
| 33 |
+
if sr != sample_rate:
|
| 34 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
|
| 35 |
+
waveform = resampler(waveform)
|
| 36 |
+
return waveform.unsqueeze(0), sample['raw_text']
|
| 37 |
+
|
| 38 |
+
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
|
| 39 |
+
"""
|
| 40 |
+
Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
|
| 41 |
+
"""
|
| 42 |
+
try:
|
| 43 |
+
prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
|
| 44 |
+
# Optional: save prompt.wav for debugging
|
| 45 |
+
# torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE)
|
| 46 |
+
|
| 47 |
+
all_speech = []
|
| 48 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot(
|
| 49 |
+
query_text,
|
| 50 |
+
prompt_text,
|
| 51 |
+
prompt_speech,
|
| 52 |
+
stream=stream,
|
| 53 |
+
text_frontend=False
|
| 54 |
+
)):
|
| 55 |
+
all_speech.append(j['tts_speech'])
|
| 56 |
+
|
| 57 |
+
# Concatenate all generated speech segments
|
| 58 |
+
combined_speech = torch.cat(all_speech, dim=-1)
|
| 59 |
+
sample_rate_val = cosyvoice.sample_rate
|
| 60 |
+
|
| 61 |
+
return {
|
| 62 |
+
'audio_tensor': combined_speech,
|
| 63 |
+
'sample_rate': sample_rate_val
|
| 64 |
+
}
|
| 65 |
+
except Exception as e:
|
| 66 |
+
print(f"Error converting text to audio: {e}")
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
|
| 70 |
+
"""
|
| 71 |
+
Run TTS on a single sample from the SciQ dataset.
|
| 72 |
+
We assume only sample['question'] is converted to speech.
|
| 73 |
+
"""
|
| 74 |
+
query = example['question']  # change this field if a different text should be synthesized
|
| 75 |
+
audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
|
| 76 |
+
if audio_result is not None:
|
| 77 |
+
return {
|
| 78 |
+
'audio_tensor': audio_result['audio_tensor'],
|
| 79 |
+
'sample_rate': audio_result['sample_rate']
|
| 80 |
+
}
|
| 81 |
+
else:
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
# ------------------------
|
| 85 |
+
# Data loading and model initialization
|
| 86 |
+
# ------------------------
|
| 87 |
+
print("Loading VoxPopuli (as Common Voice) dataset...")
|
| 88 |
+
common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
|
| 89 |
+
print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
|
| 90 |
+
|
| 91 |
+
print("Initializing CosyVoice2 model...")
|
| 92 |
+
cosyvoice = CosyVoice2(
|
| 93 |
+
'/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',  # replace with the actual model path
|
| 94 |
+
load_jit=True,
|
| 95 |
+
load_trt=False,
|
| 96 |
+
fp16=False
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
print("Loading SciQ dataset...")
|
| 100 |
+
dataset = load_dataset("allenai/sciq")
|
| 101 |
+
|
| 102 |
+
# Create the output directory
|
| 103 |
+
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
|
| 104 |
+
|
| 105 |
+
# ------------------------
|
| 106 |
+
# Main processing loop
|
| 107 |
+
# ------------------------
|
| 108 |
+
final_dataset_dict = {}  # holds the final processed data for each split
|
| 109 |
+
|
| 110 |
+
for split_name, split_dataset in dataset.items():
|
| 111 |
+
print(f"Processing split: {split_name} with {len(split_dataset)} examples")
|
| 112 |
+
split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
|
| 113 |
+
os.makedirs(split_output_dir, exist_ok=True)
|
| 114 |
+
|
| 115 |
+
# Progress file used to resume from a checkpoint
|
| 116 |
+
progress_file = os.path.join(split_output_dir, "progress.txt")
|
| 117 |
+
start_index = 0
|
| 118 |
+
if os.path.exists(progress_file):
|
| 119 |
+
try:
|
| 120 |
+
with open(progress_file, "r") as f:
|
| 121 |
+
start_index = int(f.read().strip())
|
| 122 |
+
print(f"Resuming split '{split_name}' from sample index {start_index}")
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"读取进度文件失败:{e}")
|
| 125 |
+
|
| 126 |
+
final_samples = []  # stores the processed samples
|
| 127 |
+
|
| 128 |
+
# Iterate over and process each sample
|
| 129 |
+
for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
|
| 130 |
+
# If a sample was already processed, skip synthesis and only record the existing file's metadata into final_samples
|
| 131 |
+
if i < start_index:
|
| 132 |
+
sample = split_dataset[i]
|
| 133 |
+
wav_path = os.path.join(split_output_dir, f"{i}.wav")
|
| 134 |
+
if os.path.exists(wav_path):
|
| 135 |
+
# Keep all original fields plus the audio path
|
| 136 |
+
sample_dict = {k: sample[k] for k in sample.keys()}
|
| 137 |
+
sample_dict["audio_filepath"] = wav_path
|
| 138 |
+
final_samples.append(sample_dict)
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
sample = split_dataset[i]
|
| 142 |
+
result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
|
| 143 |
+
|
| 144 |
+
if result is not None:
|
| 145 |
+
audio_tensor = result['audio_tensor']
|
| 146 |
+
if audio_tensor.dim() == 1:
|
| 147 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
| 148 |
+
sample_rate_val = result['sample_rate']
|
| 149 |
+
|
| 150 |
+
output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
|
| 151 |
+
try:
|
| 152 |
+
torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
|
| 153 |
+
except Exception as e:
|
| 154 |
+
print(f"Failed to save wav for sample {i}: {e}")
|
| 155 |
+
continue
|
| 156 |
+
|
| 157 |
+
# 保留所有原始字段 + 生成的音频路径
|
| 158 |
+
sample_dict = {k: sample[k] for k in sample.keys()}
|
| 159 |
+
sample_dict["audio_filepath"] = output_wav_path
|
| 160 |
+
final_samples.append(sample_dict)
|
| 161 |
+
else:
|
| 162 |
+
print(f"Sample {i} processing failed, no audio generated.")
|
| 163 |
+
|
| 164 |
+
# 更新进度记录
|
| 165 |
+
with open(progress_file, "w") as f:
|
| 166 |
+
f.write(str(i + 1))
|
| 167 |
+
|
| 168 |
+
# 生成 Hugging Face Dataset 并落盘
|
| 169 |
+
final_dataset_obj = Dataset.from_list(final_samples)
|
| 170 |
+
final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
|
| 171 |
+
final_dataset_obj.save_to_disk(final_dataset_save_path)
|
| 172 |
+
|
| 173 |
+
print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
|
| 174 |
+
final_dataset_dict[split_name] = final_dataset_obj
|
| 175 |
+
|
| 176 |
+
print("所有分割处理完毕,最终数据集已保存。")
|
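For reference, a minimal sketch of how the per-split output written by this script could be reloaded later. It is not part of the committed file; the OUTPUT_DATASET_PATH value below is hypothetical and should match whatever is configured at the top of sciq.py, and the "final_dataset" subdirectory name follows the layout used above.

# Minimal reload sketch (assumption: OUTPUT_DATASET_PATH matches the sciq.py config)
import os
from datasets import load_from_disk, DatasetDict

OUTPUT_DATASET_PATH = "./sciq_with_audio"  # hypothetical path for illustration only

splits = {}
for split_name in os.listdir(OUTPUT_DATASET_PATH):
    # each split directory contains the wav files, progress.txt, and a "final_dataset" folder
    split_dir = os.path.join(OUTPUT_DATASET_PATH, split_name, "final_dataset")
    if os.path.isdir(split_dir):
        splits[split_name] = load_from_disk(split_dir)

reloaded = DatasetDict(splits)
print({name: len(ds) for name, ds in reloaded.items()})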
r1-a/dataset/shp.py
ADDED
|
@@ -0,0 +1,148 @@
import re
import os  # make sure os is imported for saving
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm  # used for the progress bar

# --- Tunable filtering parameters ---
# (unchanged)
SCORE_RATIO_THRESHOLD = 2.0
MIN_CHOSEN_SCORE = 3
MIN_HISTORY_WORDS = 10
MAX_HISTORY_WORDS = 150  # adjusted to 150
MAX_URLS = 0  # adjusted to 0
MAX_NEWLINES = 5
FORBIDDEN_PATTERNS = [
    r"```.*```",
    r"\|.*\|.*\|",
]
MIN_RESPONSE_WORDS = 10

# --- Main script logic ---

def is_tts_friendly(text):
    """Check whether the text is roughly suitable for TTS."""
    # (unchanged)
    word_count = len(text.split())
    if not (MIN_HISTORY_WORDS <= word_count <= MAX_HISTORY_WORDS):
        return False
    if text.count('http') > MAX_URLS:  # uses the adjusted MAX_URLS
        return False
    if text.count('\n') > MAX_NEWLINES:
        return False
    for pattern in FORBIDDEN_PATTERNS:
        if re.search(pattern, text, re.DOTALL):
            return False
    return True

def filter_shp2_train_dataset(dataset_name="stanfordnlp/shp-2"):  # function name tweaked to reflect its purpose
    """
    Load and filter the 'train' split of the SHP-2 dataset,
    returning high-quality, TTS-friendly preference pairs.
    """
    split_to_process = 'train'  # <--- process only the 'train' split
    print(f"Loading dataset: {dataset_name}, split: {split_to_process}...")

    try:
        # --- Change 1: load the specified split directly ---
        train_dataset = load_dataset(dataset_name, split=split_to_process)
        print(f"'{split_to_process}' split loaded.")
    except Exception as e:
        print(f"Error: could not load the '{split_to_process}' split of dataset {dataset_name}.")
        print(f"Error details: {e}")
        return []  # an empty list signals failure

    filtered_data = []
    seen_histories = set()  # tracks histories already added, to keep them unique

    print(f"\nStarting to process the '{split_to_process}' split...")
    # --- Change 2: iterate over the loaded train_dataset directly ---
    for example in tqdm(train_dataset, desc=f"Filtering {split_to_process} split"):
        history = example.get("history")
        human_ref_A = example.get("human_ref_A")
        human_ref_B = example.get("human_ref_B")
        labels = example.get("labels")
        score_A = example.get("score_A")
        score_B = example.get("score_B")
        score_ratio = example.get("score_ratio")
        domain = example.get("domain")

        # Basic checks (unchanged)
        if not all([history, human_ref_A, human_ref_B, labels is not None,
                    score_A is not None, score_B is not None, score_ratio is not None, domain]):
            continue

        # Make sure this history has not been processed before (unchanged)
        if history in seen_histories:
            continue

        # Determine chosen and reject (unchanged)
        try:
            label_int = int(labels)
            if label_int == 1:
                chosen = human_ref_A
                reject = human_ref_B
                chosen_score = score_A
            elif label_int == 0:
                chosen = human_ref_B
                reject = human_ref_A
                chosen_score = score_B
            else:
                continue
        except (ValueError, TypeError):
            continue

        # --- Apply the filtering conditions (unchanged) ---
        if score_ratio is None or not isinstance(score_ratio, (int, float)) or score_ratio < SCORE_RATIO_THRESHOLD:
            continue
        if chosen_score is None or not isinstance(chosen_score, (int, float)) or chosen_score < MIN_CHOSEN_SCORE:
            continue
        if not is_tts_friendly(history):
            continue
        if len(chosen.split()) < MIN_RESPONSE_WORDS or len(reject.split()) < MIN_RESPONSE_WORDS:
            continue

        # --- All filters passed (unchanged) ---
        filtered_data.append({
            "query": history,
            "chosen": chosen,
            "reject": reject,
            "domain": domain,
        })
        seen_histories.add(history)

    print(f"\nFiltering finished. {len(filtered_data)} high-quality samples selected from the '{split_to_process}' split.")
    return filtered_data

# --- Main program ---
if __name__ == "__main__":
    # Run the filtering (call the modified function)
    filtered_examples = filter_shp2_train_dataset()

    if filtered_examples:
        # Convert the filtered data into a Hugging Face Dataset object (unchanged)
        filtered_dataset = Dataset.from_list(filtered_examples)

        # Save the filtered dataset (unchanged)
        output_path = "./shp2_filtered_tts_high_quality_train_only"  # output path updated to reflect its content
        print(f"Saving the filtered train-split data to: {output_path}")
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)  # not needed if output_path itself is a directory
        filtered_dataset.save_to_disk(output_path)
        print("Dataset saved.")

        # Preview a few samples (unchanged)
        print("\nSample preview:")
        # Load from the saved Dataset for the preview, to confirm the save succeeded
        try:
            loaded_dataset = Dataset.load_from_disk(output_path)
            for i in range(min(5, len(loaded_dataset))):
                sample = loaded_dataset[i]
                print(f"--- Sample {i+1} ---")
                print(f"Domain: {sample['domain']}")
                print(f"Query: {sample['query'][:200]}...")
                print(f"Chosen: {sample['chosen'][:200]}...")
        except Exception as e:
            print(f"Error while loading preview samples: {e}")  # extra error handling

    else:
        print("No samples matched the filters; check the filter parameters or confirm that the 'train' split exists and contains data.")
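A minimal usage sketch of the TTS-suitability check above, with invented example strings; it assumes shp.py is importable from the working directory and that the default thresholds defined in the file are in effect.

# Sketch (assumption): exercise is_tts_friendly on hand-written strings
from shp import is_tts_friendly  # assumes r1-a/dataset is the working directory

short_query = "Why?"                                   # fails MIN_HISTORY_WORDS
spoken_query = ("I have been trying to decide whether to repaint my kitchen "
                "before selling the house. Is it usually worth the effort?")
table_query = "| name | score |"                       # pipes trigger a FORBIDDEN_PATTERNS entry

for text in (short_query, spoken_query, table_query):
    print(is_tts_friendly(text), repr(text[:40]))
# Expected with the default thresholds above: False, True, False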
r1-a/dataset/shp_tts.py
ADDED
|
@@ -0,0 +1,494 @@
# --- SET CUDA DEVICE ---
# Method 1: Set environment variable BEFORE importing torch/cosyvoice
# This makes only the selected GPU visible to the script, appearing as 'cuda:0' internally.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# --- End CUDA Device Setting ---

import random
import torch
import torchaudio
# Make sure necessary types are imported
from datasets import load_dataset, Dataset, load_from_disk, Features, Value
import sys
from tqdm import tqdm
import time
import shutil  # Added for potentially removing old dataset save dirs

# Check if the specified GPU is available after setting the environment variable
if not torch.cuda.is_available():
    print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES='0'. Check your PyTorch installation, GPU drivers, and that the selected GPU exists and is functional.")
    # Force exit if the intended GPU is not found
    sys.exit(1)
else:
    # Since CUDA_VISIBLE_DEVICES restricts visibility, the first *visible* device is cuda:0
    effective_device = torch.device("cuda:0")
    try:
        print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}")  # Should show the name of the selected GPU
        print(f"Script will effectively run on: {effective_device}")
        # Perform a small check to ensure the device is usable
        _ = torch.tensor([1.0]).to(effective_device)
        print("Device check successful.")
    except Exception as e:
        print(f"ERROR: Failed CUDA device check for visible device 'cuda:0': {e}")
        sys.exit(1)


# Ensure CosyVoice path is correct
COSYVOICE_PATH = '/home/chenyifu/CosyVoice'  # Make sure this path is correct
if not os.path.isdir(COSYVOICE_PATH):
    print(f"ERROR: CosyVoice path not found: {COSYVOICE_PATH}")
    sys.exit(1)
sys.path.append(COSYVOICE_PATH)

# Import CosyVoice *after* setting the environment variable
try:
    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
    from cosyvoice.utils.file_utils import load_wav
    print("CosyVoice imported successfully.")
except ImportError as e:
    print(f"Error importing CosyVoice: {e}")
    print(f"Please ensure the path '{COSYVOICE_PATH}' is correct and the library is installed within that directory.")
    sys.exit(1)
except Exception as e:
    print(f"An unexpected error occurred during CosyVoice import: {e}")
    sys.exit(1)

# ------------------------
# Configuration parameters (MODIFIED FOR NEW DATASET)
# ------------------------
COMMON_VOICE_LANGUAGE = "en"  # Language for prompts

# --- !! MODIFIED !! ---
# Input: Path to the dataset created by the previous selection script
INPUT_DATASET_PATH = "./shp2_final_top20_percent/train_split_top20_percent_by_complexity"
# Output: Directory to save new audio files and the final dataset object
OUTPUT_DATASET_PATH = './shp2_top20_percent_with_query_audio'
# --- End MODIFIED ---

SAMPLE_RATE = 16000  # Target sample rate for TTS output (should match CosyVoice default)
MAX_TTS_RETRIES = 3
RETRY_DELAY_SECONDS = 3  # Slightly increased delay

# ------------------------
# Helper functions (GPU handling and core TTS logic - UNCHANGED as requested)
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from the VoxPopuli dataset to use as the prompt.
    (Logic remains unchanged)
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    try:
        # Use select().with_format('numpy') for potentially better memory handling with large datasets
        sample = common_voice_dataset.select([idx]).with_format('numpy')[0]
        audio = sample['audio']
        waveform = torch.tensor(audio['array'], dtype=torch.float32)  # Created on CPU
        sr = audio['sampling_rate']

        if sr != sample_rate:
            # Ensure waveform is 1D before resampling
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            if waveform.dim() != 1:
                print(f"Warning: Unexpected waveform dimension {waveform.dim()} before resampling. Skipping prompt.")
                return get_random_prompt(common_voice_dataset, sample_rate)  # Retry

            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
            waveform = resampler(waveform)

        # Ensure output is 2D [1, T]
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() > 2:
            print(f"Warning: Unexpected waveform dimension {waveform.dim()} after resampling. Skipping prompt.")
            return get_random_prompt(common_voice_dataset, sample_rate)  # Retry

        raw_text = sample.get('raw_text', '')
        if waveform.numel() == 0 or not raw_text or not raw_text.strip():
            # print("Warning: Got an empty audio or text prompt, trying again...")
            return get_random_prompt(common_voice_dataset, sample_rate)  # Retry

        # Return CPU tensor, CosyVoice inference should handle moving it
        return waveform, raw_text
    except Exception as e:
        print(f"Error getting random prompt at index {idx}: {e}. Retrying...")
        time.sleep(0.1)  # Small delay before retry
        return get_random_prompt(common_voice_dataset, sample_rate)

def text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
    """
    Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
    Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
    (Logic remains unchanged)
    """
    last_exception = None
    prompt_speech = None
    prompt_text = "N/A"

    for attempt in range(max_retries):
        try:
            # Get prompt - ensures it's valid this time
            prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
            # prompt_speech is initially on CPU

            all_speech = []
            # cosyvoice.inference_zero_shot should internally use the GPU device it was initialized on
            # (the visible cuda:0)
            inference_generator = cosyvoice.inference_zero_shot(
                text_to_convert,
                prompt_text,
                prompt_speech,  # Pass CPU tensor
                stream=stream,
                text_frontend=False  # Assuming default frontend is desired
            )
            # Generated chunks 'tts_speech' will be on the GPU
            for i, chunk in enumerate(inference_generator):
                if chunk is None:
                    print(f"Warning: Received None chunk {i} during TTS generation for text '{text_to_convert[:60]}...'")
                    continue
                if 'tts_speech' in chunk and chunk['tts_speech'] is not None and chunk['tts_speech'].numel() > 0:
                    # Ensure chunk is on the correct device (should be already, but belt-and-suspenders)
                    gpu_chunk = chunk['tts_speech'].to(effective_device)
                    all_speech.append(gpu_chunk)
                # else:  # Reduce log noise
                #     print(f"Debug: Chunk {i} missing 'tts_speech' or is empty for text '{text_to_convert[:60]}...'")

            if not all_speech:
                # Clear GPU memory cache if an error occurs during generation
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                raise ValueError("TTS inference finished but produced no valid audio chunks.")

            # combined_speech is on GPU
            combined_speech = torch.cat(all_speech, dim=-1)
            sample_rate_val = cosyvoice.sample_rate

            # --- Add a check for silence ---
            # Check max absolute amplitude; threshold might need tuning
            if torch.max(torch.abs(combined_speech)) < 0.001:
                print(f"Warning: Generated audio appears to be silent for text '{text_to_convert[:60]}...'. Retrying...")
                raise ValueError("Generated audio is silent")

            return {
                # Return GPU tensor, will be moved to CPU before saving
                'audio_tensor': combined_speech,
                'sample_rate': sample_rate_val
            }

        except Exception as e:
            last_exception = e
            print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
            print(f"  Text: '{text_to_convert[:100]}...'")
            print(f"  Prompt Text Used: '{prompt_text[:100]}...'")
            # Clear GPU cache on error
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            if attempt < max_retries - 1:
                print(f"  Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
                time.sleep(RETRY_DELAY_SECONDS)
            else:
                print(f"  All {max_retries} TTS attempts failed.")

    print(f"Failed to generate audio for text after {max_retries} attempts: '{text_to_convert[:60]}...'")
    if last_exception:
        print(f"Last error: {last_exception}")
    # Explicitly return None on failure
    return None

# --- !! MODIFIED process_example !! ---
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS on a single sample from the *SHP-2 Top 20%* dataset loaded from disk.
    Processes the example['query'] field.
    """
    # --- MODIFIED: Target the 'query' field ---
    text_to_convert = example.get('query')
    # --- End MODIFIED ---

    if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "":
        # --- MODIFIED: Update warning message ---
        print(f"Warning: Skipping example due to missing or empty 'query' field. Keys: {list(example.keys())}")
        # --- End MODIFIED ---
        return None

    # Call the unchanged text_to_audio function
    audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)

    if audio_result is not None:
        audio_tensor = audio_result['audio_tensor']  # Still on GPU here
        # Basic validation of the tensor
        if audio_tensor is None or audio_tensor.numel() == 0:
            print(f"Warning: TTS process returned empty tensor for query: '{text_to_convert[:60]}...'")
            return None

        # Ensure correct shape (should be [1, T] from text_to_audio)
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        elif audio_tensor.dim() > 2:
            print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
            audio_tensor = audio_tensor.view(1, -1)  # Flatten to [1, T]

        # Double-check for emptiness after potential reshape
        if audio_tensor.numel() == 0:
            print(f"Warning: Generated audio tensor became empty after reshape for query: '{text_to_convert[:60]}...'")
            return None

        return {
            'audio_tensor': audio_tensor,  # Return GPU tensor
            'sample_rate': audio_result['sample_rate']
        }
    else:
        # text_to_audio already prints detailed errors
        return None

# ------------------------
# Data loading and model initialization
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
try:
    # Load prompt dataset to CPU memory
    common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train', trust_remote_code=True)
    # Filter potentially problematic samples (optional, but can help)
    common_voice = common_voice.filter(lambda x: x['audio'] is not None and x['audio']['array'] is not None and x['raw_text'] is not None and len(x['raw_text'].strip()) > 5 and len(x['audio']['array']) > SAMPLE_RATE * 0.5)  # Min 0.5 sec, non-empty text
    print(f"Loaded and filtered VoxPopuli '{COMMON_VOICE_LANGUAGE}' samples: {len(common_voice)}")
    if len(common_voice) == 0:
        raise ValueError(f"VoxPopuli dataset '{COMMON_VOICE_LANGUAGE}' loaded but contains no valid samples after filtering.")
except Exception as e:
    print(f"Error loading or filtering VoxPopuli dataset: {e}")
    sys.exit(1)


print("Initializing CosyVoice2 model...")
try:
    # CosyVoice should automatically initialize on the visible device ('cuda:0')
    cosyvoice_model_path = os.path.join(COSYVOICE_PATH, 'pretrained_models/CosyVoice2-0.5B')
    if not os.path.isdir(cosyvoice_model_path):
        print(f"ERROR: CosyVoice pretrained model directory not found: {cosyvoice_model_path}")
        sys.exit(1)

    cosyvoice = CosyVoice2(
        cosyvoice_model_path,
        load_jit=True,   # Assuming JIT is preferred
        load_trt=False,  # Ensure TRT is False if not set up for this GPU
        fp16=False       # Keep FP16 False unless the GPU is known to handle it well and has enough VRAM
        # device=effective_device  # Usually not needed if CUDA_VISIBLE_DEVICES is set
    )
    print(f"CosyVoice model initialized. Target device: {effective_device}")
    # Verify model is on the correct device (optional check)
    if hasattr(cosyvoice, 'model') and hasattr(cosyvoice.model, 'device'):
        print(f"CosyVoice internal model device: {cosyvoice.model.device}")
    elif hasattr(cosyvoice, 'device'):
        print(f"CosyVoice main object device: {cosyvoice.device}")

except Exception as e:
    print(f"Error initializing CosyVoice2 model: {e}")
    if isinstance(e, RuntimeError) and 'CUDA' in str(e):
        print("This might be a CUDA initialization error. Ensure the GPU is functional, has enough memory, and required CUDA toolkit versions are compatible.")
    sys.exit(1)

# --- !! MODIFIED Dataset Loading !! ---
print(f"\nLoading the target dataset from disk: {INPUT_DATASET_PATH}")
if not os.path.exists(INPUT_DATASET_PATH):
    print(f"Error: Input dataset directory not found at '{INPUT_DATASET_PATH}'.")
    print("Please ensure the previous (selection) script ran successfully and produced the dataset at this location.")
    sys.exit(1)

try:
    input_dataset = load_from_disk(INPUT_DATASET_PATH)

    print(f"Successfully loaded dataset with {len(input_dataset)} examples.")
    if len(input_dataset) == 0:
        print("Error: The loaded dataset is empty. Cannot proceed.")
        sys.exit(1)
    # Store original features to reconstruct the final dataset correctly
    original_features = input_dataset.features
    print(f"Original features: {original_features}")
    # Check for 'query' column existence
    if 'query' not in original_features:
        print(f"Error: The loaded dataset from '{INPUT_DATASET_PATH}' does not contain the required 'query' column.")
        sys.exit(1)

except Exception as e:
    print(f"Error loading dataset from '{INPUT_DATASET_PATH}': {e}")
    sys.exit(1)
# --- End MODIFIED Dataset Loading ---


# --- Create output directories ---
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
# Subdirectory for the actual audio files
audio_output_dir = os.path.join(OUTPUT_DATASET_PATH, "audio_files")
os.makedirs(audio_output_dir, exist_ok=True)
print(f"Audio files will be saved in: {audio_output_dir}")
# Path for the progress tracking file
progress_file = os.path.join(OUTPUT_DATASET_PATH, "progress.txt")
print(f"Progress will be tracked in: {progress_file}")


# ------------------------
# Main processing loop (MODIFIED FOR SINGLE DATASET)
# ------------------------
print(f"\nStarting TTS processing for {len(input_dataset)} samples...")

start_index = 0
# Read progress file to resume if necessary
if os.path.exists(progress_file):
    try:
        with open(progress_file, "r") as f:
            content = f.read().strip()
        if content:
            start_index = int(content)
            print(f"Resuming TTS processing from sample index {start_index}")
        else:
            print(f"Progress file '{progress_file}' is empty, starting TTS from index 0.")
            start_index = 0
    except ValueError:
        print(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.")
        start_index = 0
    except Exception as e:
        print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
        start_index = 0

# List to hold dictionaries for the final dataset
final_samples = []

# --- Main Loop ---
pbar = tqdm(range(start_index, len(input_dataset)), desc=f"TTS on 'query' field", initial=start_index, total=len(input_dataset))
for i in pbar:
    sample = input_dataset[i]  # Get sample dictionary (on CPU)

    # Define unique output WAV path using the index
    # Using index is simple, assumes dataset order is stable during processing
    output_wav_filename = f"query_{i}.wav"
    output_wav_path = os.path.join(audio_output_dir, output_wav_filename)

    # --- Check if audio file already exists ---
    if os.path.exists(output_wav_path):
        # If already processed, create the dict for the final dataset
        sample_dict = dict(sample)  # Copy original data
        sample_dict["query_audio_filepath"] = output_wav_path  # Add the path field
        final_samples.append(sample_dict)
        # Update progress file even when skipping (to ensure it reflects the latest processed/checked index)
        with open(progress_file, "w") as f:
            f.write(str(i + 1))
        continue  # Skip TTS for this sample

    # --- Perform TTS on the target device ---
    # process_example handles getting the 'query' text and calling text_to_audio
    result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

    if result is not None and 'audio_tensor' in result and result['audio_tensor'] is not None:
        audio_tensor = result['audio_tensor']  # Received tensor is on GPU
        sample_rate_val = result['sample_rate']

        try:
            # --- Move tensor to CPU before saving ---
            audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)

            # Ensure shape is 2D [1, T] for torchaudio.save
            if audio_tensor_save.dim() == 1:
                audio_tensor_save = audio_tensor_save.unsqueeze(0)
            elif audio_tensor_save.dim() > 2:
                print(f"Warning: Flattening unexpected tensor shape {audio_tensor_save.shape} before saving.")
                audio_tensor_save = audio_tensor_save.view(1, -1)

            # Save the audio file
            torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)

            # Create dict for the final dataset including the new path
            sample_dict = dict(sample)  # Copy original data
            sample_dict["query_audio_filepath"] = output_wav_path  # Add the path field
            final_samples.append(sample_dict)

            # --- Explicitly delete GPU tensor ---
            del audio_tensor
            # No need to delete audio_tensor_save as it's on CPU

        except Exception as e:
            print(f"Failed to save wav for sample {i} ('query' field TTS) at {output_wav_path}: {e}")
            # Attempt to remove partially saved/corrupted file if save failed
            if os.path.exists(output_wav_path):
                try: os.remove(output_wav_path)
                except OSError: pass
            # Clear cache on save error too
            if torch.cuda.is_available(): torch.cuda.empty_cache()
    else:
        # Log failure (process_example or text_to_audio already logged details)
        query_text = sample.get('query', 'N/A')
        print(f"Sample {i} TTS failed or produced no audio after retries (Query Text: '{query_text[:60]}...'). Audio file not saved.")
        # Ensure cache is cleared even on TTS failure
        if torch.cuda.is_available(): torch.cuda.empty_cache()


    # --- Update progress file ---
    # Write the index of the *next* sample to start from if resuming
    with open(progress_file, "w") as f:
        f.write(str(i + 1))

    # --- Optional: Periodic cache clearing ---
    if i > 0 and i % 50 == 0:  # Example: clear cache every 50 iterations (adjust as needed)
        if torch.cuda.is_available():
            # print(f"Clearing CUDA cache at iteration {i}...")  # Debug log
            torch.cuda.empty_cache()


# --- Final cache clear after finishing the loop ---
if torch.cuda.is_available():
    print("Clearing final CUDA cache...")
    torch.cuda.empty_cache()

# ------------------------
# Save the final dataset (MODIFIED)
# ------------------------
print("\nTTS processing loop finished.")
if final_samples:
    print(f"Successfully processed (or skipped existing) {len(final_samples)} samples.")

    # --- Define features for the new dataset ---
    # Start with original features and add the new audio path column
    new_features_dict = original_features.copy()
    new_column_name = 'query_audio_filepath'
    if new_column_name in new_features_dict:
        print(f"Warning: Feature '{new_column_name}' already exists in original features. Overwriting.")
    new_features_dict[new_column_name] = Value('string')  # Add the new column definition
    try:
        new_features = Features(new_features_dict)
        print(f"Defined new features for saving: {new_features}")

        # --- Create the final Dataset object ---
        print("Creating final Dataset object from processed samples...")
        final_dataset_obj = Dataset.from_list(final_samples, features=new_features)

        # --- Define path to save the final dataset metadata object ---
        # This object contains the original data + the new filepath column
        final_dataset_save_path = os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio")
        print(f"Saving final dataset metadata (with audio paths) to: {final_dataset_save_path}...")

        # Ensure the target directory exists and is empty before saving
        if os.path.exists(final_dataset_save_path):
            print(f"Removing existing directory before saving: {final_dataset_save_path}")
            shutil.rmtree(final_dataset_save_path)
        # The save_to_disk function will create the directory
        # os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True)  # Not needed if saving to the dir itself

        final_dataset_obj.save_to_disk(final_dataset_save_path)
        print(f"Final dataset object saved successfully.")

    except Exception as e:
        print(f"\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(f"Error during final dataset creation or saving: {e}")
        print(f"Audio files might be saved in '{audio_output_dir}', but the final dataset object could not be created/saved.")
        print(f"Check the features and the content of 'final_samples'.")
        print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

else:
    print("Processing finished, but no samples were successfully processed or had existing audio files.")
    print(f"Check logs for TTS errors. Audio files directory: '{audio_output_dir}'.")


print("\n" + "="*60)
print(f"Script finished.")
print(f"Generated audio files are located in: '{audio_output_dir}'")
print(f"The final dataset (metadata + audio file paths) is saved at: '{os.path.join(OUTPUT_DATASET_PATH, 'processed_dataset_with_audio')}' (if saving was successful)")
print("="*60)
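A minimal follow-up sketch of how the saved metadata could be consumed downstream: the script above only stores the wav path as a string column, but the datasets library can cast such a path column to a decoded Audio feature on load. The cast is illustrative, not part of the committed pipeline.

# Sketch (assumption): decode the stored wav paths when reloading
from datasets import load_from_disk, Audio

processed = load_from_disk("./shp2_top20_percent_with_query_audio/processed_dataset_with_audio")
# casting the string path column to Audio makes examples return decoded arrays on access
with_audio = processed.cast_column("query_audio_filepath", Audio())

first = with_audio[0]
print(first["query"][:80])
print(first["query_audio_filepath"]["array"].shape, first["query_audio_filepath"]["sampling_rate"])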
r1-a/dataset/ultrachat.py
ADDED
|
@@ -0,0 +1,261 @@
import re
import os
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
import json
import string  # string module used for character checks

# --- Tunable filtering parameters ---
MIN_USER_QUERY_WORDS = 5
MAX_USER_QUERY_WORDS = 150
SIMPLE_PROMPT_PATTERNS = [
    r"^\s*(ok|yes|no|thanks?|got it|great|cool|sounds good|perfect|alright|fine|bye|goodbye)[.!\s]*$",
    r"^\s*\?+\s*$",
    r"^\s*i see\.?\s*$",
    r"^\s*you'?re welcome\.?\s*$",  # a few extra trivial replies
    r"^\s*okay then\.?\s*$",
]
CORRUPTED_ENDINGS = [" user", " assistan"]
MAX_QUERY_URLS = 0
MAX_QUERY_NEWLINES = 3
MIN_DIALOGUE_TURNS = 2  # minimum required length of the messages list

# --- New: parameters for filtering code and TTS-unfriendly content ---
FILTER_CODE_KEYWORDS = True  # whether to filter queries containing common programming keywords
CODE_KEYWORDS_PATTERN = r"\b(def|class|import|function|const|let|var|public|private|static|void|main|int|float|str|bool|return|yield|async|await|try|except|finally|if|else|for|while|switch|case|break|continue|lambda|map|filter|reduce|numpy|pandas|torch|tensorflow|react|angular|vue|console\.log|System\.out\.println)\b"  # common programming keywords (extensible)

FILTER_INLINE_CODE = True  # whether to filter queries containing Markdown inline code `...`
INLINE_CODE_PATTERN = r"`[^`]+`"

FILTER_MARKDOWN_TABLE_SEP = True  # whether to filter queries containing Markdown table separators `|---|`
MARKDOWN_TABLE_SEP_PATTERN = r"\|-+\|"

FILTER_EXCESSIVE_SPECIAL_CHARS = True  # whether to filter queries with too high a ratio of special characters
MAX_SPECIAL_CHAR_RATIO = 0.25  # maximum allowed ratio of special (non-alphanumeric, non-space) characters

FILTER_LONG_STRINGS_NO_SPACE = True  # whether to filter queries containing overly long strings without spaces
MAX_NO_SPACE_STRING_LEN = 50  # maximum allowed length of a string without spaces

QUERY_FORBIDDEN_PATTERNS = [
    r"```",  # code block markers (already present)
    # r"\|.*\|.*\|",  # simple table-row detection (possibly too broad; the separator check below is likely better)
    # additional patterns are added dynamically based on the switches above
]

# --- Main script logic ---

def is_potentially_garbled(text):
    if not text or not isinstance(text, str): return True
    for ending in CORRUPTED_ENDINGS:
        if text.endswith(ending): return True
    # Relax the bracket check slightly: only flag severely unbalanced cases
    if text.count('{') > text.count('}') + 2 or text.count('[') > text.count(']') + 2: return True
    if text.count('```') % 2 != 0: return True  # unclosed code block
    return False

def is_prompt_suitable(text, turn_index):  # turn_index added for debugging
    """Check whether a user query meets the quality and TTS requirements."""
    if not text or not isinstance(text, str):
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Not text or empty")
        return False

    # --- Basic checks ---
    word_count = len(text.split())
    if not (MIN_USER_QUERY_WORDS <= word_count <= MAX_USER_QUERY_WORDS):
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Word count {word_count} out of range [{MIN_USER_QUERY_WORDS}, {MAX_USER_QUERY_WORDS}]")
        return False

    text_stripped = text.strip()
    for pattern in SIMPLE_PROMPT_PATTERNS:
        if re.fullmatch(pattern, text_stripped, re.IGNORECASE):
            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Matched simple pattern '{pattern}'")
            return False

    if text.count('http') > MAX_QUERY_URLS:
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Too many URLs")
        return False
    if text.count('\n') > MAX_QUERY_NEWLINES:
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Too many newlines")
        return False

    # --- General forbidden-pattern check (base patterns plus dynamically added ones) ---
    current_forbidden_patterns = list(QUERY_FORBIDDEN_PATTERNS)  # copy the base list
    if FILTER_MARKDOWN_TABLE_SEP:
        current_forbidden_patterns.append(MARKDOWN_TABLE_SEP_PATTERN)

    for pattern in current_forbidden_patterns:
        # re.DOTALL lets . match newlines; re.IGNORECASE is useful for some patterns (e.g. keywords)
        search_flags = re.DOTALL
        if pattern == CODE_KEYWORDS_PATTERN:  # keywords need case-insensitive matching
            search_flags |= re.IGNORECASE
        if re.search(pattern, text, search_flags):
            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Matched forbidden pattern '{pattern}'")
            return False

    # --- New: specific checks for code and TTS-unfriendly content ---

    # 1. Check for common programming keywords
    if FILTER_CODE_KEYWORDS and re.search(CODE_KEYWORDS_PATTERN, text, re.IGNORECASE):
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains code keywords")
        return False

    # 2. Check for Markdown inline code
    if FILTER_INLINE_CODE and re.search(INLINE_CODE_PATTERN, text):
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains inline code")
        return False

    # 3. Check for overly long strings without spaces (possibly hashes, base64, code snippets, etc.)
    if FILTER_LONG_STRINGS_NO_SPACE:
        # \S matches any non-whitespace character
        if re.search(r"\S{" + str(MAX_NO_SPACE_STRING_LEN) + r",}", text):
            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains long string without spaces (>{MAX_NO_SPACE_STRING_LEN})")
            return False

    # 4. Check the ratio of special characters (non-alphanumeric, non-space)
    if FILTER_EXCESSIVE_SPECIAL_CHARS and len(text) > 0:  # avoid division by zero
        special_chars = 0
        total_chars = len(text)
        for char in text:
            # string.punctuation covers common punctuation;
            # everything that is neither alphanumeric nor whitespace counts as a special character
            if not char.isalnum() and not char.isspace():
                special_chars += 1
        ratio = special_chars / total_chars
        if ratio > MAX_SPECIAL_CHAR_RATIO:
            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Excessive special characters ratio ({ratio:.2f} > {MAX_SPECIAL_CHAR_RATIO})")
            return False

    # --- Final garbled-text check ---
    if is_potentially_garbled(text):
        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Potentially garbled")
        return False

    # print(f"DEBUG: Prompt accepted (turn {turn_index})")
    return True

# --- format_history is unchanged ---
def format_history(history_list):
    """Format the list of history messages as text."""
    if not history_list:
        return ""
    formatted = []
    for msg in history_list:
        role_tag = "[USER]" if msg.get('role') == 'user' else "[ASSISTANT]"
        content = msg.get('content', '')
        formatted.append(f"{role_tag}\n{content}")
    return "\n\n".join(formatted)


# --- filter_ultrachat_dataset_v2 is unchanged (apart from calling the updated is_prompt_suitable) ---
def filter_ultrachat_dataset_v2(dataset_name="HuggingFaceH4/ultrachat_200k", split="train_sft"):
    """
    Load and filter the UltraChat dataset (structure access corrected per the screenshot).
    Filtering uses the updated is_prompt_suitable.
    """
    print(f"Loading dataset: {dataset_name}, split: {split}...")
    try:
        dataset = load_dataset(dataset_name, split=split)
        print(f"'{split}' split loaded.")
    except Exception as e:
        print(f"Error: could not load the '{split}' split of dataset {dataset_name}.")
        print(f"Error details: {e}")
        return []

    filtered_samples = []
    processed_dialogues = 0
    extracted_samples = 0
    skipped_garbled_dialogue = 0
    skipped_short_dialogue = 0
    skipped_bad_format = 0

    print(f"\nStarting to process dialogues in the '{split}' split...")
    for dialogue in tqdm(dataset, desc="Processing dialogues"):
        processed_dialogues += 1

        messages = dialogue.get("messages")
        prompt_id = dialogue.get("prompt_id")
        initial_prompt = dialogue.get("prompt")

        if not prompt_id: continue
        if not messages or not isinstance(messages, list):
            skipped_bad_format += 1
            continue
        if len(messages) < MIN_DIALOGUE_TURNS:
            skipped_short_dialogue += 1
            continue

        dialogue_seems_garbled = False
        for msg in messages:
            content = msg.get("content")
            # Dialogue-level corruption check is now based solely on is_potentially_garbled
            if is_potentially_garbled(content):
                dialogue_seems_garbled = True
                break
        if dialogue_seems_garbled:
            skipped_garbled_dialogue += 1
            continue

        current_history_list = []
        for i, message in enumerate(messages):
            role = message.get("role")
            content = message.get("content", "").strip()

            if not role or not content:
                continue

            if role == "user":
                # Call the updated filtering function
                if is_prompt_suitable(content, i):
                    history_text = format_history(current_history_list)
                    filtered_samples.append({
                        "dialogue_id": prompt_id,
                        "turn_index": i,
                        "query": content,
                        "history": history_text
                    })
                    extracted_samples += 1
                # else:  # uncomment the inner prints to see rejection reasons
                #     pass

            current_history_list.append({"role": role, "content": content})

    print(f"\nFiltering finished.")
    print(f"Dialogues processed: {processed_dialogues}")
    print(f"Skipped for bad format: {skipped_bad_format}")
    print(f"Skipped because the messages list was too short (<{MIN_DIALOGUE_TURNS} turns): {skipped_short_dialogue}")
    print(f"Dialogues skipped as likely corrupted (based on is_potentially_garbled): {skipped_garbled_dialogue}")
    print(f"Valid user-query samples extracted: {extracted_samples}")
    return filtered_samples

# --- Main program (unchanged except for calling the V2 function) ---
if __name__ == "__main__":
    # Call the corrected filtering function
    filtered_data_list = filter_ultrachat_dataset_v2(dataset_name="HuggingFaceH4/ultrachat_200k", split="train_sft")

    if filtered_data_list:
        filtered_dataset = Dataset.from_list(filtered_data_list)
        # Output directory name updated to reflect the new filtering rules
        output_path = "./ultrachat_filtered_for_tts_preference_v3_nocode"
        print(f"\nSaving the filtered dataset to: {output_path}")
        os.makedirs(output_path, exist_ok=True)
        filtered_dataset.save_to_disk(output_path)
        print("Dataset saved.")

        print("\nSample preview (loaded from the saved Dataset):")
        try:
            loaded_dataset = Dataset.load_from_disk(output_path)
            for i in range(min(5, len(loaded_dataset))):
                sample = loaded_dataset[i]
                print(f"--- Sample {i+1} (Dialogue ID: {sample['dialogue_id']}, Turn: {sample['turn_index']}) ---")
                print(f"History (last 500 chars):\n...{sample['history'][-500:]}")
                print(f"\nQuery: {sample['query']}")
                print("-" * 20)
        except Exception as e:
            print(f"Error while loading preview samples: {e}")

    else:
        print("\nNo samples matched the filters. Possible reasons:")
        print("1. The filter parameters are too strict (check MIN/MAX word counts, SIMPLE_PROMPT_PATTERNS, the new code/TTS filter parameters, etc.).")
        print("2. The is_potentially_garbled rules are misfiring.")
        print("3. The dataset simply has no qualifying dialogues in this split.")
        print("4. (Check the skip counters printed by the script to see which stage dropped most samples.)")
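A minimal sketch of the history layout that format_history produces for an extracted sample; the two-turn dialogue below is invented purely for illustration, and the import assumes ultrachat.py is on the Python path.

# Sketch (assumption): show the [USER]/[ASSISTANT] layout built by format_history
from ultrachat import format_history  # assumes r1-a/dataset is the working directory

history = [
    {"role": "user", "content": "What are some good beginner houseplants?"},
    {"role": "assistant", "content": "Pothos and snake plants are both very forgiving."},
]
print(format_history(history))
# [USER]
# What are some good beginner houseplants?
#
# [ASSISTANT]
# Pothos and snake plants are both very forgiving.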
r1-a/dataset/ultrachat_tts.py
ADDED
|
@@ -0,0 +1,382 @@
| 1 |
+
# --- SET CUDA DEVICE ---
|
| 2 |
+
# Method 1: Set environment variable BEFORE importing torch/cosyvoice
|
| 3 |
+
# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
|
| 4 |
+
import os
|
| 5 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
| 6 |
+
# --- End CUDA Device Setting ---
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
import torch
|
| 10 |
+
import torchaudio
|
| 11 |
+
# Make sure necessary types are imported
|
| 12 |
+
from datasets import load_dataset, Dataset, load_from_disk, Features, Value
|
| 13 |
+
import sys
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import time
|
| 16 |
+
import shutil # Added for potentially removing old dataset save dirs
|
| 17 |
+
|
| 18 |
+
# Check if the specified GPU is available after setting the environment variable
|
| 19 |
+
if not torch.cuda.is_available():
|
| 20 |
+
print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. Check your PyTorch installation, GPU drivers, and that GPU 1 exists and is functional.")
|
| 21 |
+
# Force exit if the intended GPU is not found
|
| 22 |
+
sys.exit(1)
|
| 23 |
+
else:
|
| 24 |
+
# Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0
|
| 25 |
+
effective_device = torch.device("cuda:0")
|
| 26 |
+
try:
|
| 27 |
+
print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1
|
| 28 |
+
print(f"Script will effectively run on: {effective_device}")
|
| 29 |
+
# Perform a small check to ensure the device is usable
|
| 30 |
+
_ = torch.tensor([1.0]).to(effective_device)
|
| 31 |
+
print("Device check successful.")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"ERROR: Failed CUDA device check for visible device 'cuda:0' (original GPU 1): {e}")
|
| 34 |
+
sys.exit(1)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Ensure CosyVoice path is correct
|
| 38 |
+
COSYVOICE_PATH = '/root/autodl-tmp/CosyVoice' # Make sure this path is correct
|
| 39 |
+
if not os.path.isdir(COSYVOICE_PATH):
|
| 40 |
+
print(f"ERROR: CosyVoice path not found: {COSYVOICE_PATH}")
|
| 41 |
+
sys.exit(1)
|
| 42 |
+
sys.path.append(COSYVOICE_PATH)
|
| 43 |
+
|
| 44 |
+
# Import CosyVoice *after* setting the environment variable
|
| 45 |
+
try:
|
| 46 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 47 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 48 |
+
print("CosyVoice imported successfully.")
|
| 49 |
+
except ImportError as e:
|
| 50 |
+
print(f"Error importing CosyVoice: {e}")
|
| 51 |
+
print(f"Please ensure the path '{COSYVOICE_PATH}' is correct and the library is installed within that directory.")
|
| 52 |
+
sys.exit(1)
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"An unexpected error occurred during CosyVoice import: {e}")
|
| 55 |
+
sys.exit(1)
|
| 56 |
+
|
| 57 |
+
# ------------------------
|
| 58 |
+
# 配置参数 (MODIFIED FOR Selected UltraChat DATASET)
|
| 59 |
+
# ------------------------
|
| 60 |
+
COMMON_VOICE_LANGUAGE = "en" # Language for prompts
|
| 61 |
+
|
| 62 |
+
# --- !! MODIFIED !! ---
|
| 63 |
+
# Input: Path to the SELECTED UltraChat dataset (Top 20%) from the previous script
|
| 64 |
+
INPUT_DATASET_PATH = "./ultrachat_final_top20_percent/ultrachat_top20_percent_by_complexity"
|
| 65 |
+
# Output: Directory to save new audio files and the final dataset object for THIS specific dataset
|
| 66 |
+
OUTPUT_DATASET_PATH = './ultrachat_top20_percent_with_query_audio' # New distinct output path
|
| 67 |
+
# --- End MODIFIED ---
|
| 68 |
+
|
| 69 |
+
SAMPLE_RATE = 16000 # Target sample rate for TTS output (should match CosyVoice default)
|
| 70 |
+
MAX_TTS_RETRIES = 3
|
| 71 |
+
RETRY_DELAY_SECONDS = 3
|
| 72 |
+
|
| 73 |
+
# ------------------------
|
| 74 |
+
# 辅助函数 (GPU handling and core TTS logic - UNCHANGED as requested)
|
| 75 |
+
# ------------------------
|
| 76 |
+
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Randomly draw one utterance and its transcript from the VoxPopuli dataset to use as the TTS prompt.
    (Logic remains unchanged from previous TTS script)
    """
    idx = random.randint(0, len(common_voice_dataset) - 1)
    try:
        sample = common_voice_dataset.select([idx]).with_format('numpy')[0]
        audio = sample['audio']
        waveform = torch.tensor(audio['array'], dtype=torch.float32)  # CPU
        sr = audio['sampling_rate']
        if sr != sample_rate:
            if waveform.dim() > 1: waveform = waveform.mean(dim=0)
            if waveform.dim() != 1: return get_random_prompt(common_voice_dataset, sample_rate)
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
            waveform = resampler(waveform)
        if waveform.dim() == 1: waveform = waveform.unsqueeze(0)
        elif waveform.dim() > 2: return get_random_prompt(common_voice_dataset, sample_rate)
        raw_text = sample.get('raw_text', '')
        if waveform.numel() == 0 or not raw_text or not raw_text.strip():
            return get_random_prompt(common_voice_dataset, sample_rate)
        return waveform, raw_text  # Return CPU tensor
    except Exception as e:
        time.sleep(0.1)
        return get_random_prompt(common_voice_dataset, sample_rate)

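# --- Added note on get_random_prompt above: it retries by recursing whenever the
# drawn sample is unusable (wrong tensor shape, empty audio, empty transcript).
# With the filtering applied to the prompt dataset below this usually terminates
# quickly, but there is no explicit recursion-depth guard, so a badly broken prompt
# dataset could in principle recurse deeply.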
def text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
    """
    Convert the input text to speech with the CosyVoice2 model, using a randomly drawn prompt for zero-shot inference.
    Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
    (Logic remains unchanged from previous TTS script)
    """
    last_exception = None
    prompt_speech, prompt_text = None, "N/A"
    for attempt in range(max_retries):
        try:
            prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)  # CPU tensor
            all_speech = []
            inference_generator = cosyvoice.inference_zero_shot(
                text_to_convert, prompt_text, prompt_speech, stream=stream, text_frontend=False
            )
            for i, chunk in enumerate(inference_generator):  # Chunks on GPU
                if chunk is None: continue
                if 'tts_speech' in chunk and chunk['tts_speech'] is not None and chunk['tts_speech'].numel() > 0:
                    gpu_chunk = chunk['tts_speech'].to(effective_device)
                    all_speech.append(gpu_chunk)
            if not all_speech:
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                raise ValueError("TTS inference finished but produced no valid audio chunks.")
            combined_speech = torch.cat(all_speech, dim=-1)  # GPU tensor
            sample_rate_val = cosyvoice.sample_rate
            if torch.max(torch.abs(combined_speech)) < 0.001:
                raise ValueError("Generated audio is silent")
            return {'audio_tensor': combined_speech, 'sample_rate': sample_rate_val}  # Return GPU tensor
        except Exception as e:
            last_exception = e
            print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
            print(f"  Text: '{text_to_convert[:100]}...'")
            # print(f"  Prompt Text Used: '{prompt_text[:100]}...'")  # Reduce log noise
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            if attempt < max_retries - 1:
                print(f"  Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
                time.sleep(RETRY_DELAY_SECONDS)
            else:
                print(f"  All {max_retries} TTS attempts failed.")
    return None

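# --- Added usage sketch (commented out; not executed by this script). It shows how
# text_to_audio is consumed further down: `cosyvoice` and `common_voice` are the
# objects initialized later in this file, and the returned tensor stays on the GPU
# until it is moved to the CPU for saving.
# result = text_to_audio("How do plants convert sunlight into energy?", cosyvoice, common_voice)
# if result is not None:
#     wav = result['audio_tensor'].detach().cpu().to(torch.float32)
#     if wav.dim() == 1:
#         wav = wav.unsqueeze(0)                     # torchaudio.save expects [channels, time]
#     torchaudio.save("example_query.wav", wav, result['sample_rate'])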
# --- PROCESS EXAMPLE (Targets 'query' field) ---
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS on a single sample from the *Selected Top 20% UltraChat* dataset loaded from disk.
    Processes the example['query'] field.
    """
    text_to_convert = example.get('query')
    # Get identifiers for logging, if they exist in this dataset version
    dialogue_id = example.get('dialogue_id', 'N/A')
    turn_index = example.get('turn_index', 'N/A')  # May not be present if not carried over

    if not text_to_convert or not isinstance(text_to_convert, str) or not text_to_convert.strip():
        print(f"Warning: Skipping example (ID: {dialogue_id}, Turn: {turn_index}) due to missing or empty 'query' field.")
        return None

    # Call the unchanged text_to_audio function
    audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)

    if audio_result is not None:
        audio_tensor = audio_result['audio_tensor']  # Still on GPU here
        if audio_tensor is None or audio_tensor.numel() == 0:
            print(f"Warning: TTS process returned empty tensor for query (ID: {dialogue_id}, Turn: {turn_index}): '{text_to_convert[:60]}...'")
            return None
        if audio_tensor.dim() == 1: audio_tensor = audio_tensor.unsqueeze(0)
        elif audio_tensor.dim() > 2:
            print(f"Warning: Generated audio tensor unexpected shape {audio_tensor.shape} (ID: {dialogue_id}, Turn: {turn_index}). Flattening.")
            audio_tensor = audio_tensor.view(1, -1)  # Flatten to [1, T]
        if audio_tensor.numel() == 0:
            print(f"Warning: Generated audio tensor became empty after reshape for query (ID: {dialogue_id}, Turn: {turn_index}): '{text_to_convert[:60]}...'")
            return None
        return {
            'audio_tensor': audio_tensor,  # Return GPU tensor
            'sample_rate': audio_result['sample_rate']
        }
    else:
        return None  # Errors logged within text_to_audio

# ------------------------
# Data loading and model initialization (Model and Prompt Dataset Loading Unchanged)
# ------------------------
print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
try:
    common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train', trust_remote_code=True)
    common_voice = common_voice.filter(lambda x: x['audio'] is not None and x['audio']['array'] is not None and x['raw_text'] is not None and len(x['raw_text'].strip()) > 5 and len(x['audio']['array']) > SAMPLE_RATE * 0.5)
    print(f"Loaded and filtered VoxPopuli '{COMMON_VOICE_LANGUAGE}' samples: {len(common_voice)}")
    if len(common_voice) == 0: raise ValueError(f"VoxPopuli '{COMMON_VOICE_LANGUAGE}' loaded but no valid samples after filtering.")
except Exception as e:
    print(f"Error loading or filtering VoxPopuli dataset: {e}")
    sys.exit(1)

print("Initializing CosyVoice2 model...")
try:
    # CosyVoice initialization remains the same
    cosyvoice_model_path = os.path.join(COSYVOICE_PATH, 'pretrained_models/CosyVoice2-0.5B')
    if not os.path.isdir(cosyvoice_model_path): raise FileNotFoundError(f"CosyVoice pretrained model directory not found: {cosyvoice_model_path}")
    cosyvoice = CosyVoice2(
        cosyvoice_model_path, load_jit=True, load_trt=False, fp16=False
    )
    print(f"CosyVoice model initialized. Target device: {effective_device}")
except Exception as e:
    print(f"Error initializing CosyVoice2 model: {e}")
    if isinstance(e, RuntimeError) and 'CUDA' in str(e): print("CUDA initialization error? Check GPU 1 status/memory.")
    sys.exit(1)

# --- !! MODIFIED Selected UltraChat Dataset Loading !! ---
print(f"\nLoading the target Selected UltraChat (Top 20%) dataset from disk: {INPUT_DATASET_PATH}")
if not os.path.exists(INPUT_DATASET_PATH):
    print(f"Error: Input dataset directory not found at '{INPUT_DATASET_PATH}'.")
    print("Please ensure the UltraChat Selection script ran successfully and produced the dataset at this location.")
    sys.exit(1)

try:
    input_dataset = load_from_disk(INPUT_DATASET_PATH)

    print(f"Successfully loaded Selected UltraChat dataset with {len(input_dataset)} examples.")
    if len(input_dataset) == 0:
        print("Error: The loaded dataset is empty. Cannot proceed.")
        sys.exit(1)
    # Store original features to reconstruct the final dataset correctly
    original_features = input_dataset.features
    print(f"Original features: {original_features}")
    # Check for 'query' column existence (essential for TTS)
    if 'query' not in original_features:
        print(f"Error: The loaded dataset from '{INPUT_DATASET_PATH}' does not contain the required 'query' column.")
        sys.exit(1)

except Exception as e:
    print(f"Error loading dataset from '{INPUT_DATASET_PATH}': {e}")
    sys.exit(1)
# --- End MODIFIED Dataset Loading ---


# --- Create output directories ---
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
audio_output_dir = os.path.join(OUTPUT_DATASET_PATH, "audio_files")
os.makedirs(audio_output_dir, exist_ok=True)
print(f"Audio files will be saved in: {audio_output_dir}")
progress_file = os.path.join(OUTPUT_DATASET_PATH, "progress.txt")
print(f"Progress will be tracked in: {progress_file}")

# ------------------------
# Main processing loop (MODIFIED FOR SINGLE Selected UltraChat DATASET)
# ------------------------
# --- !! MODIFIED: Update log message !! ---
print(f"\nStarting TTS processing for {len(input_dataset)} Selected UltraChat (Top 20%) samples...")

start_index = 0
# Read progress file to resume if necessary
if os.path.exists(progress_file):
    try:
        with open(progress_file, "r") as f:
            content = f.read().strip()
            if content: start_index = int(content)
        print(f"Resuming TTS processing from sample index {start_index}")
    except Exception as e:
        print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
        start_index = 0

# List to hold dictionaries for the final dataset
final_samples = []

# --- Main Loop ---
# --- !! MODIFIED: Update progress bar description !! ---
pbar = tqdm(range(start_index, len(input_dataset)), desc="TTS on Selected UltraChat 'query'", initial=start_index, total=len(input_dataset))
for i in pbar:
    sample = input_dataset[i]  # Get sample dictionary (on CPU)

    # Define unique output WAV path using the index
    output_wav_filename = f"query_{i}.wav"
    output_wav_path = os.path.join(audio_output_dir, output_wav_filename)

    # --- Check if audio file already exists ---
    if os.path.exists(output_wav_path):
        sample_dict = dict(sample)
        sample_dict["query_audio_filepath"] = output_wav_path  # Add path field
        final_samples.append(sample_dict)
        with open(progress_file, "w") as f: f.write(str(i + 1))
        continue  # Skip TTS

    # --- Perform TTS on the target device ---
    result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

    if result is not None and 'audio_tensor' in result and result['audio_tensor'] is not None:
        audio_tensor = result['audio_tensor']  # GPU tensor
        sample_rate_val = result['sample_rate']
        try:
            # Move tensor to CPU before saving
            audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
            if audio_tensor_save.dim() == 1: audio_tensor_save = audio_tensor_save.unsqueeze(0)
            elif audio_tensor_save.dim() > 2: audio_tensor_save = audio_tensor_save.view(1, -1)

            torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)

            # Create dict for the final dataset
            sample_dict = dict(sample)
            sample_dict["query_audio_filepath"] = output_wav_path  # Add path field
            final_samples.append(sample_dict)

            del audio_tensor  # Delete GPU tensor

        except Exception as e:
            # Log error with identifiers if available
            dialogue_id = sample.get('dialogue_id', 'N/A')
            turn_index = sample.get('turn_index', 'N/A')
            print(f"Failed to save wav for sample {i} (ID: {dialogue_id}, Turn: {turn_index}) at {output_wav_path}: {e}")
            if os.path.exists(output_wav_path):
                try: os.remove(output_wav_path)
                except OSError: pass
            if torch.cuda.is_available(): torch.cuda.empty_cache()
    else:
        # Failure logged in process_example/text_to_audio
        if torch.cuda.is_available(): torch.cuda.empty_cache()

    # --- Update progress file ---
    with open(progress_file, "w") as f: f.write(str(i + 1))

    # --- Optional: Periodic cache clearing ---
    if i > 0 and i % 50 == 0:
        if torch.cuda.is_available(): torch.cuda.empty_cache()

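# --- Added note on the loop above: progress.txt stores a single integer (the next
# index to process) and is rewritten after every sample, while WAV files that already
# exist are skipped. Re-running the script after an interruption therefore resumes
# where it left off without redoing finished TTS work.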
# --- Final cache clear after finishing the loop ---
if torch.cuda.is_available():
    print("Clearing final CUDA cache...")
    torch.cuda.empty_cache()

# ------------------------
# Save the final dataset (MODIFIED FOR Selected UltraChat)
# ------------------------
print("\nTTS processing loop finished.")
if final_samples:
    # --- !! MODIFIED: Update log message !! ---
    print(f"Successfully processed (or skipped existing) {len(final_samples)} Selected UltraChat (Top 20%) samples.")

    # --- Define features for the new dataset ---
    new_features_dict = original_features.copy()
    new_column_name = 'query_audio_filepath'  # Name of the new column
    if new_column_name in new_features_dict:
        print(f"Warning: Feature '{new_column_name}' already exists in original features. Overwriting.")
    new_features_dict[new_column_name] = Value('string')  # Add the new column definition
    try:
        new_features = Features(new_features_dict)
        print(f"Defined new features for saving: {new_features}")

        # --- Create the final Dataset object ---
        print("Creating final Dataset object from processed samples...")
        final_dataset_obj = Dataset.from_list(final_samples, features=new_features)

        # --- Define path to save the final dataset metadata object ---
        final_dataset_save_path = os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio")
        # --- !! MODIFIED: Update log message !! ---
        print(f"Saving final Selected UltraChat (Top 20%) dataset metadata (with audio paths) to: {final_dataset_save_path}...")

        # Ensure the target directory exists and is empty before saving
        if os.path.exists(final_dataset_save_path):
            print(f"Removing existing directory before saving: {final_dataset_save_path}")
            shutil.rmtree(final_dataset_save_path)

        final_dataset_obj.save_to_disk(final_dataset_save_path)
        print("Final dataset object saved successfully.")

    except Exception as e:
        print("\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(f"Error during final dataset creation or saving: {e}")
        print(f"Audio files might be saved in '{audio_output_dir}', but the final dataset object could not be created/saved.")
        print("Check the features and the content of 'final_samples'.")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

else:
    print("Processing finished, but no samples were successfully processed or had existing audio files.")
    print(f"Check logs for TTS errors. Audio files directory: '{audio_output_dir}'.")


print("\n" + "="*60)
# --- !! MODIFIED: Update final log messages !! ---
print("Script finished for Selected UltraChat (Top 20%) dataset.")
print(f"Generated audio files are located in: '{audio_output_dir}'")
print(f"The final dataset (metadata + audio file paths) is saved at: '{os.path.join(OUTPUT_DATASET_PATH, 'processed_dataset_with_audio')}' (if saving was successful)")
print("="*60)
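# --- Added follow-up sketch (commented out): one way to reload the saved dataset and
# decode the generated audio lazily. Audio() is the standard `datasets` feature type;
# the column name matches the one written above.
# from datasets import load_from_disk, Audio
# reloaded = load_from_disk(os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio"))
# reloaded = reloaded.cast_column("query_audio_filepath", Audio())
# print(reloaded[0]["query_audio_filepath"])  # dict with 'path', 'array', 'sampling_rate'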
r1-a/prompt_only_examine.py
ADDED
@@ -0,0 +1,48 @@
from datasets import load_from_disk
import torchaudio
import os

# IMPORTANT: When you load and use this dataset, your CWD should have the same
# relationship to the audio files as it did when the dataset was created,
# OR you need to manually resolve the paths.

# Path where the HF dataset was saved
dataset_path = '/root/autodl-tmp/audio-r1/r1-a/dataset/prompt_only_fully_merged_with_audio/final_hf_dataset_relative_paths_v4'
ds = load_from_disk(dataset_path)
breakpoint()  # Use this to inspect the dataset structure if needed
# Get a relative path string from the dataset
relative_path_from_dataset = ds[0]['question_audio_relative_path']  # Or your field name
print(f"Relative path from dataset: {relative_path_from_dataset}")

# To load the audio, this relative path needs to resolve correctly from your *current* CWD.

# Option 1: If your CWD is already the intended base directory:
# current_cwd_when_loading = os.getcwd()
# print(f"Current CWD when loading: {current_cwd_when_loading}")
# full_path_to_audio = os.path.abspath(relative_path_from_dataset)  # os.path.abspath resolves against the CWD

# Option 2: If you know the "root" directory the relative paths were created against
# (here, the dataset was built with the CWD at '/root/autodl-tmp/audio-r1/r1-a/dataset/'),
# join that base directory to the relative path instead. This is the safer choice when
# you move the dataset and the audio files together while keeping their relative layout:
# the stored relative paths keep working as long as you resolve them against the new base
# (or make that base your CWD). A concrete sketch follows below.

# The simplest way if you are in the correct CWD when loading:
full_path_to_audio = os.path.join(os.getcwd(), relative_path_from_dataset)  # May be wrong if the relative path contains '..'
full_path_to_audio = os.path.abspath(relative_path_from_dataset)  # Usually what you want when the CWD is the intended base

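# --- Added sketch for Option 2 (hypothetical base directory; adjust to wherever the
# audio files actually live). The dataset was created with the CWD at
# '/root/autodl-tmp/audio-r1/r1-a/dataset/', so joining that base to the stored
# relative path reconstructs the original absolute location without changing the CWD.
AUDIO_BASE_DIR = '/root/autodl-tmp/audio-r1/r1-a/dataset'  # assumption, not read from the dataset
candidate_path_option2 = os.path.normpath(os.path.join(AUDIO_BASE_DIR, relative_path_from_dataset))
print(f"Option 2 candidate path (base dir + relative path): {candidate_path_option2}")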
print(f"Attempting to load from (resolved path): {full_path_to_audio}")
|
| 41 |
+
|
| 42 |
+
if os.path.exists(full_path_to_audio):
|
| 43 |
+
waveform, sample_rate = torchaudio.load(full_path_to_audio)
|
| 44 |
+
print(f"Loaded audio: waveform shape {waveform.shape}, sample rate {sample_rate}")
|
| 45 |
+
else:
|
| 46 |
+
print(f"Audio file NOT FOUND at resolved path: {full_path_to_audio}")
|
| 47 |
+
print("Ensure your Current Working Directory is set correctly so the relative path can be resolved.")
|
| 48 |
+
print(f"Alternatively, manually construct the absolute path if you know where the audio files are relative to a fixed base.")
|