1f committed on
Commit 19891ba · verified · 1 Parent(s): ec8b32c

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +27 -0
  2. r1-a/dataset/ai2_arc.py +175 -0
  3. r1-a/dataset/alpaca.py +346 -0
  4. r1-a/dataset/commonsense.py +175 -0
  5. r1-a/dataset/examqa.py +440 -0
  6. r1-a/dataset/examqa_rewrite.py +487 -0
  7. r1-a/dataset/final_tts.py +316 -0
  8. r1-a/dataset/gsm8k.py +169 -0
  9. r1-a/dataset/gsm8k_with_audio/test/299.wav +3 -0
  10. r1-a/dataset/gsm8k_with_audio/test/301.wav +3 -0
  11. r1-a/dataset/gsm8k_with_audio/test/302.wav +3 -0
  12. r1-a/dataset/gsm8k_with_audio/test/314.wav +3 -0
  13. r1-a/dataset/gsm8k_with_audio/test/316.wav +3 -0
  14. r1-a/dataset/gsm8k_with_audio/test/350.wav +3 -0
  15. r1-a/dataset/gsm8k_with_audio/test/358.wav +3 -0
  16. r1-a/dataset/gsm8k_with_audio/test/359.wav +3 -0
  17. r1-a/dataset/gsm8k_with_audio/test/369.wav +3 -0
  18. r1-a/dataset/gsm8k_with_audio/test/372.wav +3 -0
  19. r1-a/dataset/gsm8k_with_audio/test/376.wav +3 -0
  20. r1-a/dataset/gsm8k_with_audio/test/385.wav +3 -0
  21. r1-a/dataset/gsm8k_with_audio/test/394.wav +3 -0
  22. r1-a/dataset/gsm8k_with_audio/test/395.wav +3 -0
  23. r1-a/dataset/gsm8k_with_audio/test/397.wav +3 -0
  24. r1-a/dataset/gsm8k_with_audio/test/400.wav +3 -0
  25. r1-a/dataset/gsm8k_with_audio/test/401.wav +3 -0
  26. r1-a/dataset/gsm8k_with_audio/test/447.wav +3 -0
  27. r1-a/dataset/gsm8k_with_audio/test/45.wav +3 -0
  28. r1-a/dataset/gsm8k_with_audio/test/450.wav +3 -0
  29. r1-a/dataset/gsm8k_with_audio/test/454.wav +3 -0
  30. r1-a/dataset/gsm8k_with_audio/test/457.wav +3 -0
  31. r1-a/dataset/gsm8k_with_audio/test/458.wav +3 -0
  32. r1-a/dataset/gsm8k_with_audio/test/459.wav +3 -0
  33. r1-a/dataset/gsm8k_with_audio/test/463.wav +3 -0
  34. r1-a/dataset/gsm8k_with_audio/test/465.wav +3 -0
  35. r1-a/dataset/gsm8k_with_audio/test/515.wav +3 -0
  36. r1-a/dataset/gsm8k_with_audio/test/877.wav +0 -0
  37. r1-a/dataset/gsm8k_with_audio/test/964.wav +0 -0
  38. r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json +20 -0
  39. r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json +13 -0
  40. r1-a/dataset/gsm8k_with_audio/test/progress.txt +1 -0
  41. r1-a/dataset/pkusafe.py +171 -0
  42. r1-a/dataset/pkusafe_tts.py +279 -0
  43. r1-a/dataset/retry_rewrite.py +442 -0
  44. r1-a/dataset/retts.py +559 -0
  45. r1-a/dataset/sciq.py +176 -0
  46. r1-a/dataset/shp.py +148 -0
  47. r1-a/dataset/shp_tts.py +494 -0
  48. r1-a/dataset/ultrachat.py +261 -0
  49. r1-a/dataset/ultrachat_tts.py +382 -0
  50. r1-a/prompt_only_examine.py +48 -0
.gitattributes CHANGED
@@ -33,3 +33,30 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ r1-a/dataset/gsm8k_with_audio/test/450.wav filter=lfs diff=lfs merge=lfs -text
37
+ r1-a/dataset/gsm8k_with_audio/test/45.wav filter=lfs diff=lfs merge=lfs -text
38
+ r1-a/dataset/gsm8k_with_audio/test/358.wav filter=lfs diff=lfs merge=lfs -text
39
+ r1-a/dataset/gsm8k_with_audio/test/369.wav filter=lfs diff=lfs merge=lfs -text
40
+ r1-a/dataset/gsm8k_with_audio/test/316.wav filter=lfs diff=lfs merge=lfs -text
41
+ r1-a/dataset/gsm8k_with_audio/test/454.wav filter=lfs diff=lfs merge=lfs -text
42
+ r1-a/dataset/gsm8k_with_audio/test/376.wav filter=lfs diff=lfs merge=lfs -text
43
+ r1-a/dataset/gsm8k_with_audio/test/395.wav filter=lfs diff=lfs merge=lfs -text
44
+ r1-a/dataset/gsm8k_with_audio/test/359.wav filter=lfs diff=lfs merge=lfs -text
45
+ r1-a/dataset/gsm8k_with_audio/test/447.wav filter=lfs diff=lfs merge=lfs -text
46
+ r1-a/dataset/gsm8k_with_audio/test/299.wav filter=lfs diff=lfs merge=lfs -text
47
+ r1-a/dataset/gsm8k_with_audio/test/302.wav filter=lfs diff=lfs merge=lfs -text
48
+ r1-a/dataset/gsm8k_with_audio/test/394.wav filter=lfs diff=lfs merge=lfs -text
49
+ r1-a/dataset/gsm8k_with_audio/test/350.wav filter=lfs diff=lfs merge=lfs -text
50
+ r1-a/dataset/gsm8k_with_audio/test/385.wav filter=lfs diff=lfs merge=lfs -text
51
+ r1-a/dataset/gsm8k_with_audio/test/401.wav filter=lfs diff=lfs merge=lfs -text
52
+ r1-a/dataset/gsm8k_with_audio/test/463.wav filter=lfs diff=lfs merge=lfs -text
53
+ r1-a/dataset/gsm8k_with_audio/test/458.wav filter=lfs diff=lfs merge=lfs -text
54
+ r1-a/dataset/gsm8k_with_audio/test/457.wav filter=lfs diff=lfs merge=lfs -text
55
+ r1-a/dataset/gsm8k_with_audio/test/515.wav filter=lfs diff=lfs merge=lfs -text
56
+ r1-a/dataset/gsm8k_with_audio/test/301.wav filter=lfs diff=lfs merge=lfs -text
57
+ r1-a/dataset/gsm8k_with_audio/test/372.wav filter=lfs diff=lfs merge=lfs -text
58
+ r1-a/dataset/gsm8k_with_audio/test/314.wav filter=lfs diff=lfs merge=lfs -text
59
+ r1-a/dataset/gsm8k_with_audio/test/397.wav filter=lfs diff=lfs merge=lfs -text
60
+ r1-a/dataset/gsm8k_with_audio/test/465.wav filter=lfs diff=lfs merge=lfs -text
61
+ r1-a/dataset/gsm8k_with_audio/test/459.wav filter=lfs diff=lfs merge=lfs -text
62
+ r1-a/dataset/gsm8k_with_audio/test/400.wav filter=lfs diff=lfs merge=lfs -text
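Side note (not part of the committed .gitattributes): the rules above track each generated wav file individually. Assuming all TTS outputs stay under r1-a/dataset/, a single glob rule such as

r1-a/dataset/**/*.wav filter=lfs diff=lfs merge=lfs -text

would cover current and future wav files without per-file entries.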
r1-a/dataset/ai2_arc.py ADDED
@@ -0,0 +1,175 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ from datasets import load_dataset, Dataset
6
+ import sys
7
+ from tqdm import tqdm
8
+
9
+ sys.path.append('/root/autodl-tmp/CosyVoice')
10
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
11
+ from cosyvoice.utils.file_utils import load_wav
12
+
13
+ # ------------------------
14
+ # Configuration
15
+ # ------------------------
16
+ COMMON_VOICE_LANGUAGE = "en"
17
+ DATASET_NAME = "ai2_arc"
18
+ OUTPUT_DATASET_PATH = './arc_easy_with_audio' # output directory
19
+ SAMPLE_RATE = 16000
20
+
21
+ # ------------------------
22
+ # Helper functions
23
+ # ------------------------
24
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
25
+ """
26
+ Randomly sample one utterance and its transcript from the VoxPopuli dataset (used in place of the original Common Voice) to serve as the prompt.
27
+ """
28
+ idx = random.randint(0, len(common_voice_dataset) - 1)
29
+ sample = common_voice_dataset.select([idx])[0]
30
+ audio = sample['audio']
31
+ waveform = torch.tensor(audio['array'], dtype=torch.float32)
32
+ sr = audio['sampling_rate']
33
+ if sr != sample_rate:
34
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
35
+ waveform = resampler(waveform)
36
+ return waveform.unsqueeze(0), sample['raw_text']
37
+
38
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
39
+ """
40
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
41
+ """
42
+ try:
43
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
44
+
45
+ all_speech = []
46
+ for i, j in enumerate(cosyvoice.inference_zero_shot(
47
+ query_text,
48
+ prompt_text,
49
+ prompt_speech,
50
+ stream=stream,
51
+ text_frontend=False
52
+ )):
53
+ all_speech.append(j['tts_speech'])
54
+
55
+ # Concatenate all generated speech chunks
56
+ combined_speech = torch.cat(all_speech, dim=-1)
57
+ sample_rate_val = cosyvoice.sample_rate
58
+
59
+ return {
60
+ 'audio_tensor': combined_speech,
61
+ 'sample_rate': sample_rate_val
62
+ }
63
+ except Exception as e:
64
+ print(f"Error converting text to audio: {e}")
65
+ return None
66
+
67
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
68
+ """
69
+ Run TTS on a single example from the AI2 ARC dataset.
70
+ In this example, TTS is applied only to the sample['question'] field.
71
+ """
72
+ query = example['question']
73
+ audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
74
+ if audio_result is not None:
75
+ return {
76
+ 'audio_tensor': audio_result['audio_tensor'],
77
+ 'sample_rate': audio_result['sample_rate']
78
+ }
79
+ else:
80
+ return None
81
+
82
+ # ------------------------
83
+ # Data loading and model initialization
84
+ # ------------------------
85
+ print("Loading VoxPopuli (as Common Voice) dataset...")
86
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
87
+ print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
88
+
89
+ print("Initializing CosyVoice2 model...")
90
+ cosyvoice = CosyVoice2(
91
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # replace with the actual model path
92
+ load_jit=True,
93
+ load_trt=False,
94
+ fp16=False
95
+ )
96
+
97
+ print("Loading ARC-Easy dataset...")
98
+ # To process ARC-Challenge instead, change "ARC-Easy" to "ARC-Challenge"
99
+ dataset = load_dataset("allenai/ai2_arc", "ARC-Easy")
100
+
101
+ # Create output directory
102
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
103
+
104
+ # ------------------------
105
+ # Main processing loop
106
+ # ------------------------
107
+ final_dataset_dict = {} # holds the final processed data for each split
108
+
109
+ for split_name, split_dataset in dataset.items():
110
+ print(f"Processing split: {split_name} with {len(split_dataset)} examples")
111
+ split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
112
+ os.makedirs(split_output_dir, exist_ok=True)
113
+
114
+ # Progress record used to resume from a checkpoint
115
+ progress_file = os.path.join(split_output_dir, "progress.txt")
116
+ start_index = 0
117
+ if os.path.exists(progress_file):
118
+ try:
119
+ with open(progress_file, "r") as f:
120
+ start_index = int(f.read().strip())
121
+ print(f"Resuming split '{split_name}' from sample index {start_index}")
122
+ except Exception as e:
123
+ print(f"Failed to read progress file: {e}")
124
+
125
+ final_samples = []
126
+
127
+ # Iterate over and process each example
128
+ for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
129
+ # If this index was already processed, skip TTS and only record the existing file's metadata in final_samples
130
+ if i < start_index:
131
+ sample = split_dataset[i]
132
+ wav_path = os.path.join(split_output_dir, f"{i}.wav")
133
+ if os.path.exists(wav_path):
134
+ # Keep all original fields + the audio path
135
+ sample_dict = {k: sample[k] for k in sample.keys()}
136
+ sample_dict["audio_filepath"] = wav_path
137
+ final_samples.append(sample_dict)
138
+ continue
139
+
140
+ sample = split_dataset[i]
141
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
142
+
143
+ if result is not None:
144
+ audio_tensor = result['audio_tensor']
145
+ if audio_tensor.dim() == 1:
146
+ audio_tensor = audio_tensor.unsqueeze(0)
147
+ sample_rate_val = result['sample_rate']
148
+
149
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
150
+ try:
151
+ torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
152
+ except Exception as e:
153
+ print(f"Failed to save wav for sample {i}: {e}")
154
+ continue
155
+
156
+ # Keep all original fields + the generated audio path
157
+ sample_dict = {k: sample[k] for k in sample.keys()}
158
+ sample_dict["audio_filepath"] = output_wav_path
159
+ final_samples.append(sample_dict)
160
+ else:
161
+ print(f"Sample {i} processing failed, no audio generated.")
162
+
163
+ # Update the progress record
164
+ with open(progress_file, "w") as f:
165
+ f.write(str(i + 1))
166
+
167
+ # Build a Hugging Face Dataset and save it to disk
168
+ final_dataset_obj = Dataset.from_list(final_samples)
169
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
170
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
171
+
172
+ print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
173
+ final_dataset_dict[split_name] = final_dataset_obj
174
+
175
+ print("All splits processed; final datasets saved.")
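Side note (not part of the committed file): a minimal sketch of how the per-split output written by ai2_arc.py above could be consumed downstream. It assumes the layout the script produces (<OUTPUT_DATASET_PATH>/<split>/final_dataset plus <i>.wav files); the 'train' split name is only an example.

from datasets import load_from_disk, Audio

split_dir = "./arc_easy_with_audio/train"           # assumed split directory
ds = load_from_disk(f"{split_dir}/final_dataset")   # original ARC fields + 'audio_filepath'
# Decode the referenced wav files lazily as an Audio feature (keeps the native sample rate).
ds = ds.cast_column("audio_filepath", Audio())
print(ds[0]["question"], ds[0]["audio_filepath"]["sampling_rate"])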
r1-a/dataset/alpaca.py ADDED
@@ -0,0 +1,346 @@
1
+ # --- SET CUDA DEVICE ---
2
+ # Method 1: Set environment variable BEFORE importing torch/cosyvoice
3
+ # This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
4
+ import os
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
6
+ # --- End CUDA Device Setting ---
7
+
8
+ import random
9
+ import torch
10
+ import torchaudio
11
+ from datasets import load_dataset, Dataset, load_from_disk
12
+ import sys
13
+ from tqdm import tqdm
14
+ import time
15
+
16
+ # Check if the specified GPU is available after setting the environment variable
17
+ if not torch.cuda.is_available():
18
+ print("WARNING: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. Check your PyTorch installation and GPU drivers.")
19
+ print("Attempting to run on CPU, but this will be very slow.")
20
+ # Decide if you want to exit or proceed on CPU
21
+ # sys.exit(1) # Uncomment to exit if GPU not found
22
+ effective_device = torch.device("cpu")
23
+ else:
24
+ # Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0
25
+ effective_device = torch.device("cuda:0")
26
+ print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1
27
+ print(f"Script will effectively run on: {effective_device}")
28
+
29
+
30
+ sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct
31
+ # Import CosyVoice *after* setting the environment variable
32
+ try:
33
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
34
+ from cosyvoice.utils.file_utils import load_wav
35
+ except ImportError as e:
36
+ print(f"Error importing CosyVoice: {e}")
37
+ print("Please ensure the path '/root/autodl-tmp/CosyVoice' is correct and the library is installed.")
38
+ sys.exit(1)
39
+
40
+ # ------------------------
41
+ # Configuration
42
+ # ------------------------
43
+ COMMON_VOICE_LANGUAGE = "en"
44
+ FILTERED_ALPACA_PATH = './alpaca_filtered_for_spoken_dialogue_v2'
45
+ SPLITS_TO_PROCESS = ['train']
46
+ OUTPUT_DATASET_PATH = './alpaca_filtered_spoken_with_output_audio' # Keep output path distinct
47
+ SAMPLE_RATE = 16000
48
+ MAX_TTS_RETRIES = 3
49
+ RETRY_DELAY_SECONDS = 2
50
+
51
+ # ------------------------
52
+ # Helper functions (no changes needed here; they run on the visible device)
53
+ # ------------------------
54
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
55
+ """
56
+ Randomly sample one utterance and its transcript from the VoxPopuli dataset (used in place of the original Common Voice) to serve as the prompt.
57
+ """
58
+ idx = random.randint(0, len(common_voice_dataset) - 1)
59
+ sample = common_voice_dataset.select([idx])[0]
60
+ audio = sample['audio']
61
+ waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU
62
+ sr = audio['sampling_rate']
63
+ if sr != sample_rate:
64
+ if waveform.dim() > 1:
65
+ waveform = waveform.mean(dim=0)
66
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
67
+ waveform = resampler(waveform)
68
+ if waveform.dim() == 1:
69
+ waveform = waveform.unsqueeze(0)
70
+ if waveform.numel() == 0 or not sample['raw_text']:
71
+ print("Warning: Got an empty prompt, trying again...")
72
+ return get_random_prompt(common_voice_dataset, sample_rate)
73
+ # Return CPU tensor, CosyVoice inference should handle moving it
74
+ return waveform, sample['raw_text']
75
+
76
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
77
+ """
78
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
79
+ Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
80
+ """
81
+ last_exception = None
82
+ for attempt in range(max_retries):
83
+ try:
84
+ # prompt_speech is initially on CPU
85
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
86
+
87
+ all_speech = []
88
+ # cosyvoice.inference_zero_shot should internally use the GPU device it was initialized on
89
+ # (which should be the visible cuda:0, i.e., original cuda:1)
90
+ inference_generator = cosyvoice.inference_zero_shot(
91
+ query_text,
92
+ prompt_text,
93
+ prompt_speech, # Pass CPU tensor
94
+ stream=stream,
95
+ text_frontend=False
96
+ )
97
+ # Generated chunks 'tts_speech' will be on the GPU
98
+ for i, chunk in enumerate(inference_generator):
99
+ if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
100
+ all_speech.append(chunk['tts_speech'])
101
+ else:
102
+ print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
103
+
104
+ if not all_speech:
105
+ # Clear GPU memory cache if an error occurs during generation
106
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
107
+ raise ValueError("TTS inference finished but produced no audio chunks.")
108
+
109
+ # combined_speech is on GPU
110
+ combined_speech = torch.cat(all_speech, dim=-1)
111
+ sample_rate_val = cosyvoice.sample_rate
112
+
113
+ return {
114
+ # Return GPU tensor, will be moved to CPU before saving
115
+ 'audio_tensor': combined_speech,
116
+ 'sample_rate': sample_rate_val
117
+ }
118
+ except Exception as e:
119
+ last_exception = e
120
+ print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
121
+ print(f"Text: '{query_text[:100]}...'")
122
+ print(f"Prompt Text: '{prompt_text[:100]}...'")
123
+ # Clear GPU cache on error as well
124
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
125
+ if attempt < max_retries - 1:
126
+ print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
127
+ time.sleep(RETRY_DELAY_SECONDS)
128
+ else:
129
+ print(f"All {max_retries} TTS attempts failed.")
130
+
131
+ print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
132
+ print(f"Last error: {last_exception}")
133
+ return None
134
+
135
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
136
+ """
137
+ Run TTS on a single example from the filtered Alpaca dataset loaded from disk.
138
+ Processes example['instruction'] + example['input'].
139
+ """
140
+ text_to_convert = example.get('instruction')+example.get('input')
141
+ if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "":
142
+ print(f"Warning: Skipping example due to missing or empty 'instruction'/'input' text: {example.keys()}")
143
+ return None
144
+
145
+ audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)
146
+
147
+ if audio_result is not None:
148
+ audio_tensor = audio_result['audio_tensor'] # Still on GPU here
149
+ if audio_tensor.dim() == 1:
150
+ audio_tensor = audio_tensor.unsqueeze(0)
151
+ elif audio_tensor.dim() > 2:
152
+ print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
153
+ audio_tensor = audio_tensor.view(1, -1)
154
+
155
+ if audio_tensor.numel() == 0:
156
+ print(f"Warning: Generated audio tensor is empty for output text: '{text_to_convert[:60]}...'")
157
+ # Clear GPU cache even for empty tensor? Maybe not needed.
158
+ return None
159
+
160
+ return {
161
+ 'audio_tensor': audio_tensor, # Return GPU tensor
162
+ 'sample_rate': audio_result['sample_rate']
163
+ }
164
+ else:
165
+ return None
166
+
167
+ # ------------------------
168
+ # Data loading and model initialization
169
+ # ------------------------
170
+ print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
171
+ try:
172
+ # Load prompt dataset to CPU memory
173
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
174
+ print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
175
+ if len(common_voice) == 0:
176
+ raise ValueError("VoxPopuli dataset loaded but contains no samples.")
177
+ except Exception as e:
178
+ print(f"Error loading VoxPopuli dataset: {e}")
179
+ sys.exit(1)
180
+
181
+
182
+ print("Initializing CosyVoice2 model...")
183
+ try:
184
+ # CosyVoice should automatically initialize on the visible device ('cuda:0', which is original 'cuda:1')
185
+ # No explicit device='cuda:1' needed here due to CUDA_VISIBLE_DEVICES
186
+ cosyvoice = CosyVoice2(
187
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',
188
+ load_jit=True,
189
+ load_trt=False, # Ensure TRT is False if not set up for GPU 1
190
+ fp16=False # Check if GPU 1 supports FP16 well if you enable this
191
+ # device=effective_device # Usually not needed if CUDA_VISIBLE_DEVICES is set, but uncomment if CosyVoice requires it explicitly
192
+ )
193
+ print(f"CosyVoice model initialized. It should be using device: {effective_device}")
194
+ except Exception as e:
195
+ print(f"Error initializing CosyVoice2 model: {e}")
196
+ # Try to get more info if it's a CUDA error
197
+ if isinstance(e, RuntimeError) and 'CUDA' in str(e):
198
+ print("This might be a CUDA initialization error. Ensure GPU 1 is functional and has enough memory.")
199
+ sys.exit(1)
200
+
201
+ print(f"Loading pre-filtered Alpaca dataset(s) from disk: {FILTERED_ALPACA_PATH}")
202
+ dataset_dict = {}
203
+ loaded_splits_count = 0
204
+ for split_name in SPLITS_TO_PROCESS:
205
+ split_dir_name = f"{split_name}_dataset"
206
+ split_path = os.path.join(FILTERED_ALPACA_PATH, split_dir_name)
207
+ print(f"Attempting to load split '{split_name}' from: {split_path}")
208
+ try:
209
+ # Load dataset to CPU memory
210
+ split_dataset = load_from_disk(split_path)
211
+
212
+ if not split_dataset:
213
+ print(f"Warning: Dataset loaded from '{split_path}' is empty or invalid. Skipping this split.")
214
+ continue
215
+ dataset_dict[split_name] = split_dataset
216
+ print(f"Successfully loaded split '{split_name}' with {len(split_dataset)} examples.")
217
+ loaded_splits_count += 1
218
+ except FileNotFoundError:
219
+ print(f"Info: Filtered dataset split not found at '{split_path}'. Skipping this split.")
220
+ except Exception as e:
221
+ print(f"Error loading pre-filtered dataset split from '{split_path}': {e}. Skipping this split.")
222
+
223
+ if loaded_splits_count == 0:
224
+ print(f"Error: Could not load any dataset splits from '{FILTERED_ALPACA_PATH}' using splits '{SPLITS_TO_PROCESS}'.")
225
+ sys.exit(1)
226
+
227
+ # Create output directory
228
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
229
+
230
+ # ------------------------
231
+ # Main processing loop
232
+ # ------------------------
233
+ final_dataset_dict = {}
234
+
235
+ for split_name, split_dataset in dataset_dict.items():
236
+ print(f"\nProcessing loaded split: {split_name} with {len(split_dataset)} examples")
237
+ split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
238
+ os.makedirs(split_output_dir, exist_ok=True)
239
+
240
+ progress_file = os.path.join(split_output_dir, "progress.txt")
241
+ start_index = 0
242
+ # ... (progress file reading logic remains the same)
243
+ if os.path.exists(progress_file):
244
+ try:
245
+ with open(progress_file, "r") as f:
246
+ content = f.read().strip()
247
+ if content:
248
+ start_index = int(content)
249
+ print(f"Resuming split '{split_name}' TTS from sample index {start_index}")
250
+ else:
251
+ print(f"Progress file '{progress_file}' is empty, starting TTS from index 0.")
252
+ start_index = 0
253
+ except ValueError:
254
+ print(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.")
255
+ start_index = 0
256
+ except Exception as e:
257
+ print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
258
+ start_index = 0
259
+
260
+
261
+ final_samples = []
262
+
263
+ pbar = tqdm(range(start_index, len(split_dataset)), desc=f"TTS on '{split_name}' output field", initial=start_index, total=len(split_dataset))
264
+ for i in pbar:
265
+ sample = split_dataset[i] # Sample data is on CPU
266
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
267
+
268
+ if os.path.exists(output_wav_path):
269
+ sample_dict = {k: sample[k] for k in sample.keys()}
270
+ sample_dict["audio_filepath"] = output_wav_path
271
+ final_samples.append(sample_dict)
272
+ with open(progress_file, "w") as f:
273
+ f.write(str(i + 1))
274
+ continue
275
+
276
+ # --- Perform TTS on the target device ---
277
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
278
+
279
+ if result is not None:
280
+ audio_tensor = result['audio_tensor'] # Received tensor is on GPU
281
+ sample_rate_val = result['sample_rate']
282
+
283
+ try:
284
+ # --- Move tensor to CPU before saving ---
285
+ audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
286
+ if audio_tensor_save.dim() == 1:
287
+ audio_tensor_save = audio_tensor_save.unsqueeze(0)
288
+ elif audio_tensor_save.dim() > 2:
289
+ audio_tensor_save = audio_tensor_save.view(1, -1)
290
+
291
+ torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
292
+
293
+ sample_dict = {k: sample[k] for k in sample.keys()}
294
+ sample_dict["audio_filepath"] = output_wav_path
295
+ final_samples.append(sample_dict)
296
+
297
+ # --- Explicitly delete GPU tensor and clear cache periodically? ---
298
+ # Can sometimes help prevent memory creep in long loops
299
+ del audio_tensor
300
+ # if i % 50 == 0: # Example: clear cache every 50 iterations
301
+ # if torch.cuda.is_available(): torch.cuda.empty_cache()
302
+
303
+ except Exception as e:
304
+ print(f"Failed to save wav for sample {i} ('output' field TTS) at {output_wav_path}: {e}")
305
+ # Clear cache on save error too, just in case
306
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
307
+ else:
308
+ print(f"Sample {i} TTS failed after retries (Instruction: '{sample.get('instruction', 'N/A')[:60]}...'), no audio generated.")
309
+ # No tensor to delete if result is None
310
+
311
+ # Update progress file
312
+ with open(progress_file, "w") as f:
313
+ f.write(str(i + 1))
314
+
315
+ # --- Optional: Add more frequent cache clearing ---
316
+ # if i % 20 == 0 and torch.cuda.is_available(): # Clear more often if memory is tight
317
+ # torch.cuda.empty_cache()
318
+
319
+
320
+ # --- Final cache clear after finishing a split ---
321
+ if torch.cuda.is_available():
322
+ torch.cuda.empty_cache()
323
+
324
+ # ... (Saving final dataset logic remains the same)
325
+ if final_samples:
326
+ final_dataset_obj = Dataset.from_list(final_samples)
327
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
328
+ try:
329
+ print(f"Saving final dataset for split '{split_name}' (with new audio paths) to {final_dataset_save_path}...")
330
+ os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True)
331
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
332
+ print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with new audio paths for 'output' field.")
333
+ final_dataset_dict[split_name] = final_dataset_obj
334
+ except Exception as e:
335
+ print(f"Error saving final dataset for split '{split_name}' to disk: {e}")
336
+ else:
337
+ print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.")
338
+
339
+
340
+ print("="*30)
341
+ if final_dataset_dict:
342
+ print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.")
343
+ print(f"Processed splits: {list(final_dataset_dict.keys())}")
344
+ else:
345
+ print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.")
346
+ print("="*30)
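Side note (not part of the committed file): the loop above resumes from a per-split progress.txt and skips indices whose wav file already exists. A condensed sketch of that bookkeeping, with hypothetical helper names, under the same assumptions:

def read_progress(progress_file: str) -> int:
    # Next sample index to process; 0 if the file is missing, empty, or unreadable.
    try:
        with open(progress_file, "r") as f:
            content = f.read().strip()
        return int(content) if content else 0
    except (FileNotFoundError, ValueError, OSError):
        return 0

def write_progress(progress_file: str, next_index: int) -> None:
    # Record that samples [0, next_index) are done so a restarted run can skip them.
    with open(progress_file, "w") as f:
        f.write(str(next_index))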
r1-a/dataset/commonsense.py ADDED
@@ -0,0 +1,175 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ from datasets import load_dataset, Dataset
6
+ import sys
7
+ from tqdm import tqdm
8
+
9
+ sys.path.append('/root/autodl-tmp/CosyVoice')
10
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
11
+ from cosyvoice.utils.file_utils import load_wav
12
+
13
+ # ------------------------
14
+ # Configuration
15
+ # ------------------------
16
+ COMMON_VOICE_LANGUAGE = "en"
17
+ DATASET_NAME = "commonsense_qa"
18
+ OUTPUT_DATASET_PATH = './commonsense_qa_with_audio' # output directory
19
+ SAMPLE_RATE = 16000
20
+
21
+ # ------------------------
22
+ # Helper functions
23
+ # ------------------------
24
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
25
+ """
26
+ Randomly sample one utterance and its transcript from the VoxPopuli dataset (used here in place of the original Common Voice) to serve as the prompt.
27
+ """
28
+ idx = random.randint(0, len(common_voice_dataset) - 1)
29
+ sample = common_voice_dataset.select([idx])[0]
30
+ audio = sample['audio']
31
+ waveform = torch.tensor(audio['array'], dtype=torch.float32)
32
+ sr = audio['sampling_rate']
33
+ if sr != sample_rate:
34
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
35
+ waveform = resampler(waveform)
36
+ return waveform.unsqueeze(0), sample['raw_text']
37
+
38
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
39
+ """
40
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
41
+ """
42
+ try:
43
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
44
+
45
+ all_speech = []
46
+ for i, j in enumerate(cosyvoice.inference_zero_shot(
47
+ query_text,
48
+ prompt_text,
49
+ prompt_speech,
50
+ stream=stream,
51
+ text_frontend=False
52
+ )):
53
+ all_speech.append(j['tts_speech'])
54
+
55
+ # Concatenate all generated speech chunks
56
+ combined_speech = torch.cat(all_speech, dim=-1)
57
+ sample_rate_val = cosyvoice.sample_rate
58
+
59
+ return {
60
+ 'audio_tensor': combined_speech,
61
+ 'sample_rate': sample_rate_val
62
+ }
63
+ except Exception as e:
64
+ print(f"Error converting text to audio: {e}")
65
+ return None
66
+
67
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
68
+ """
69
+ Run TTS on a single example from the CommonsenseQA dataset.
70
+ In this example, TTS is applied only to the sample['question'] field.
71
+ """
72
+ query = example['question']
73
+ audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
74
+ if audio_result is not None:
75
+ return {
76
+ 'audio_tensor': audio_result['audio_tensor'],
77
+ 'sample_rate': audio_result['sample_rate']
78
+ }
79
+ else:
80
+ return None
81
+
82
+ # ------------------------
83
+ # Data loading and model initialization
84
+ # ------------------------
85
+ print("Loading VoxPopuli (as Common Voice) dataset...")
86
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
87
+ print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
88
+
89
+ print("Initializing CosyVoice2 model...")
90
+ cosyvoice = CosyVoice2(
91
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # replace with the actual model path
92
+ load_jit=True,
93
+ load_trt=False,
94
+ fp16=False
95
+ )
96
+
97
+ print("Loading Commonsense QA dataset...")
98
+ dataset = load_dataset("tau/commonsense_qa")
99
+ # To process only the train split, use dataset = load_dataset("tau/commonsense_qa", split="train")
100
+
101
+ # Create output directory
102
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
103
+
104
+ # ------------------------
105
+ # Main processing loop
106
+ # ------------------------
107
+ final_dataset_dict = {} # holds the final processed data for each split
108
+
109
+ for split_name, split_dataset in dataset.items():
110
+ print(f"Processing split: {split_name} with {len(split_dataset)} examples")
111
+ split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
112
+ os.makedirs(split_output_dir, exist_ok=True)
113
+
114
+ # Progress record used to resume from a checkpoint
115
+ progress_file = os.path.join(split_output_dir, "progress.txt")
116
+ start_index = 0
117
+ if os.path.exists(progress_file):
118
+ try:
119
+ with open(progress_file, "r") as f:
120
+ start_index = int(f.read().strip())
121
+ print(f"Resuming split '{split_name}' from sample index {start_index}")
122
+ except Exception as e:
123
+ print(f"Failed to read progress file: {e}")
124
+
125
+ final_samples = []
126
+
127
+ # Iterate over and process each example
128
+ for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
129
+ # If this index was already processed, skip TTS and only record the existing file's metadata in final_samples
130
+ if i < start_index:
131
+ sample = split_dataset[i]
132
+ wav_path = os.path.join(split_output_dir, f"{i}.wav")
133
+ if os.path.exists(wav_path):
134
+ # Keep all original fields + the audio path
135
+ sample_dict = {k: sample[k] for k in sample.keys()}
136
+ sample_dict["audio_filepath"] = wav_path
137
+ final_samples.append(sample_dict)
138
+ continue
139
+
140
+ sample = split_dataset[i]
141
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
142
+
143
+ if result is not None:
144
+ audio_tensor = result['audio_tensor']
145
+ if audio_tensor.dim() == 1:
146
+ audio_tensor = audio_tensor.unsqueeze(0)
147
+ sample_rate_val = result['sample_rate']
148
+
149
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
150
+ try:
151
+ torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
152
+ except Exception as e:
153
+ print(f"Failed to save wav for sample {i}: {e}")
154
+ continue
155
+
156
+ # Keep all original fields + the generated audio path
157
+ sample_dict = {k: sample[k] for k in sample.keys()}
158
+ sample_dict["audio_filepath"] = output_wav_path
159
+ final_samples.append(sample_dict)
160
+ else:
161
+ print(f"Sample {i} processing failed, no audio generated.")
162
+
163
+ # Update the progress record
164
+ with open(progress_file, "w") as f:
165
+ f.write(str(i + 1))
166
+
167
+ # Build a Hugging Face Dataset and save it to disk
168
+ final_dataset_obj = Dataset.from_list(final_samples)
169
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
170
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
171
+
172
+ print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
173
+ final_dataset_dict[split_name] = final_dataset_obj
174
+
175
+ print("All splits processed; final datasets saved.")
r1-a/dataset/examqa.py ADDED
@@ -0,0 +1,440 @@
1
+ # --- SET CUDA DEVICE ---
2
+ # Method 1: Set environment variable BEFORE importing torch/cosyvoice
3
+ # This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
4
+ import os
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1" # <-- Keep your original setting
6
+ # --- End CUDA Device Setting ---
7
+
8
+ import random
9
+ import torch
10
+ import torchaudio
11
+ # Import load_from_disk to load the dataset saved by your LLM script
12
+ from datasets import load_dataset, Dataset, load_from_disk, Features, Value, Sequence, ClassLabel # Added Features etc. for robustness
13
+ import sys
14
+ from tqdm import tqdm
15
+ import time
16
+ import logging # Add logging
17
+ import json # For fallback saving
18
+
19
+ # Check if the specified GPU is available after setting the environment variable
20
+ if not torch.cuda.is_available():
21
+ print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES. Cannot run TTS on GPU.")
22
+ print("Check your PyTorch installation, GPU drivers, and CUDA setup.")
23
+ sys.exit(1) # Exit if GPU is required and not found
24
+ else:
25
+ # Since CUDA_VISIBLE_DEVICES is set, the first *visible* device is cuda:0
26
+ effective_device = torch.device("cuda:0")
27
+ print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}")
28
+ print(f"Script will effectively run TTS inference on: {effective_device}")
29
+
30
+
31
+ sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct
32
+ # Import CosyVoice *after* setting the environment variable
33
+ try:
34
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
35
+ from cosyvoice.utils.file_utils import load_wav
36
+ except ImportError as e:
37
+ print(f"Error importing CosyVoice: {e}")
38
+ print("Please ensure the path '/root/autodl-tmp/CosyVoice' is correct and the library is installed.")
39
+ sys.exit(1)
40
+
41
+ # Setup basic logging
42
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
43
+
44
+ # ------------------------
45
+ # Configuration
46
+ # ------------------------
47
+ COMMON_VOICE_LANGUAGE = "en"
48
+ # --- Path to the dataset output by the LLM rephrasing script ---
49
+ REPHRASED_DATASET_PATH = './Multi-subject-RLVR_rephrased/train_processed_final' # <-- ADJUST IF YOUR PATH IS DIFFERENT
50
+ # --- Output path for THIS TTS script ---
51
+ TTS_OUTPUT_PATH = './Multi-subject-RLVR_rephrased_with_audio' # <-- New path for results
52
+ SAMPLE_RATE = 16000
53
+ MAX_TTS_RETRIES = 3
54
+ RETRY_DELAY_SECONDS = 2
55
+ # Define the assumed split name for directory structure (even if only one split)
56
+ ASSUMED_INPUT_SPLIT = "train"
57
+
58
+ # ------------------------
59
+ # Helper functions (no changes needed here; includes retry logic and uses the visible GPU)
60
+ # ------------------------
61
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
62
+ """
63
+ Randomly sample one utterance and its transcript from the VoxPopuli dataset (used in place of the original Common Voice) to serve as the prompt.
64
+ """
65
+ idx = random.randint(0, len(common_voice_dataset) - 1)
66
+ sample = common_voice_dataset.select([idx])[0]
67
+ audio = sample['audio']
68
+ waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU
69
+ sr = audio['sampling_rate']
70
+ if sr != sample_rate:
71
+ if waveform.dim() > 1:
72
+ waveform = waveform.mean(dim=0)
73
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
74
+ waveform = resampler(waveform)
75
+ if waveform.dim() == 1:
76
+ waveform = waveform.unsqueeze(0)
77
+ if waveform.numel() == 0 or not sample['raw_text']:
78
+ logging.warning("Got an empty prompt, trying again...")
79
+ return get_random_prompt(common_voice_dataset, sample_rate)
80
+ # Return CPU tensor, CosyVoice inference should handle moving it
81
+ return waveform, sample['raw_text']
82
+
83
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
84
+ """
85
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
86
+ Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
87
+ """
88
+ last_exception = None
89
+ for attempt in range(max_retries):
90
+ try:
91
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
92
+
93
+ all_speech = []
94
+ inference_generator = cosyvoice.inference_zero_shot(
95
+ query_text,
96
+ prompt_text,
97
+ prompt_speech, # Pass CPU tensor
98
+ stream=stream,
99
+ text_frontend=False
100
+ )
101
+ for i, chunk in enumerate(inference_generator):
102
+ if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
103
+ all_speech.append(chunk['tts_speech'])
104
+ else:
105
+ logging.warning(f"TTS Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
106
+
107
+ if not all_speech:
108
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
109
+ raise ValueError("TTS inference finished but produced no audio chunks.")
110
+
111
+ combined_speech = torch.cat(all_speech, dim=-1) # On GPU
112
+ sample_rate_val = cosyvoice.sample_rate
113
+
114
+ return {
115
+ 'audio_tensor': combined_speech, # Return GPU tensor
116
+ 'sample_rate': sample_rate_val
117
+ }
118
+ except Exception as e:
119
+ last_exception = e
120
+ logging.error(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}", exc_info=True)
121
+ logging.error(f"Failed Text: '{query_text[:100]}...'")
122
+ logging.error(f"Prompt Text Used: '{prompt_text[:100]}...'")
123
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
124
+ if attempt < max_retries - 1:
125
+ wait_time = RETRY_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(0.5, 1.5)
126
+ logging.warning(f"Retrying TTS with a different prompt in {wait_time:.2f}s...")
127
+ time.sleep(wait_time)
128
+ else:
129
+ logging.error(f"All {max_retries} TTS attempts failed.")
130
+
131
+ logging.error(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
132
+ logging.error(f"Last TTS error: {last_exception}")
133
+ return None
134
+
135
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
136
+ """
137
+ Run TTS on a single example from the LLM-rephrased dataset loaded from disk.
138
+ Processes example['query_rephrased']. <--- Target the rephrased query
139
+ """
140
+ # --- Target the 'query_rephrased' field from the LLM output dataset ---
141
+ text_to_convert = example.get('query_rephrased') # <--- Use 'query_rephrased' field
142
+ if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "":
143
+ original_query = example.get('query', [{}])[0].get('content', 'Original Query Missing')[:50]
144
+ logging.warning(f"Skipping TTS for example due to missing or empty 'query_rephrased' field. Original query started with: '{original_query}...'. Status was: {example.get('query_rephrased_status','N/A')}")
145
+ return None
146
+
147
+ # --- Use the text_to_audio function with retry logic ---
148
+ audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)
149
+
150
+ if audio_result is not None:
151
+ audio_tensor = audio_result['audio_tensor'] # Still on GPU here
152
+ if audio_tensor.dim() == 1:
153
+ audio_tensor = audio_tensor.unsqueeze(0)
154
+ elif audio_tensor.dim() > 2:
155
+ logging.warning(f"Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
156
+ audio_tensor = audio_tensor.view(1, -1)
157
+
158
+ if audio_tensor.numel() == 0:
159
+ logging.warning(f"Generated audio tensor is empty for rephrased query: '{text_to_convert[:60]}...'")
160
+ return None
161
+
162
+ return {
163
+ 'audio_tensor': audio_tensor, # Return GPU tensor
164
+ 'sample_rate': audio_result['sample_rate']
165
+ }
166
+ else:
167
+ # text_to_audio already logged the failure
168
+ return None
169
+
170
+ # ------------------------
171
+ # Data loading and model initialization
172
+ # ------------------------
173
+ logging.info("Loading VoxPopuli (as Common Voice) dataset for prompts...")
174
+ try:
175
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
176
+ logging.info(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
177
+ if len(common_voice) == 0:
178
+ raise ValueError("VoxPopuli dataset loaded but contains no samples.")
179
+ except Exception as e:
180
+ logging.error(f"Error loading VoxPopuli dataset: {e}", exc_info=True)
181
+ sys.exit(1)
182
+
183
+
184
+ logging.info("Initializing CosyVoice2 model...")
185
+ try:
186
+ cosyvoice = CosyVoice2(
187
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',
188
+ load_jit=True,
189
+ load_trt=False,
190
+ fp16=False # Consider setting to True if VRAM is an issue and you have FP16 support
191
+ )
192
+ logging.info(f"CosyVoice model initialized on effective device: {effective_device}")
193
+ except Exception as e:
194
+ logging.error(f"Error initializing CosyVoice2 model: {e}", exc_info=True)
195
+ sys.exit(1)
196
+
197
+ logging.info(f"Loading rephrased dataset from disk: {REPHRASED_DATASET_PATH}")
198
+ try:
199
+ # --- Load the single dataset saved by the LLM script ---
200
+ rephrased_dataset = load_from_disk(REPHRASED_DATASET_PATH)
201
+ if not rephrased_dataset:
202
+ raise ValueError(f"Dataset loaded from '{REPHRASED_DATASET_PATH}' is empty or invalid.")
203
+ # --- Wrap it in a dict to match the loop structure expecting splits ---
204
+ # Use the assumed split name as the key
205
+ dataset_dict = {ASSUMED_INPUT_SPLIT: rephrased_dataset}
206
+ logging.info(f"Successfully loaded dataset with {len(rephrased_dataset)} examples.")
207
+ except FileNotFoundError:
208
+ logging.error(f"Error: Rephrased dataset not found at '{REPHRASED_DATASET_PATH}'.")
209
+ logging.error("Please ensure the LLM rephrasing script ran successfully and saved data to the correct location.")
210
+ sys.exit(1)
211
+ except Exception as e:
212
+ logging.error(f"Error loading rephrased dataset from '{REPHRASED_DATASET_PATH}': {e}", exc_info=True)
213
+ sys.exit(1)
214
+
215
+ # Create output directory
216
+ os.makedirs(TTS_OUTPUT_PATH, exist_ok=True)
217
+
218
+ # ------------------------
219
+ # Main processing loop
220
+ # ------------------------
221
+ final_dataset_dict_for_tracking = {} # To track which final datasets were saved
222
+
223
+ # Iterate through the dictionary (will contain only one split, e.g., 'train')
224
+ for split_name, split_dataset in dataset_dict.items():
225
+ logging.info(f"\nProcessing split: {split_name} with {len(split_dataset)} examples for TTS")
226
+ # Output directory for *this* script's results (audio + final dataset)
227
+ split_output_dir = os.path.join(TTS_OUTPUT_PATH, split_name)
228
+ os.makedirs(split_output_dir, exist_ok=True)
229
+ logging.info(f"Audio files and final data for this split will be saved in: {split_output_dir}")
230
+
231
+ # Progress record for resuming from a checkpoint (specific to this TTS process)
232
+ progress_file = os.path.join(split_output_dir, "tts_progress.txt") # Use specific name
233
+ start_index = 0
234
+ if os.path.exists(progress_file):
235
+ try:
236
+ with open(progress_file, "r") as f:
237
+ content = f.read().strip()
238
+ if content:
239
+ start_index = int(content)
240
+ logging.info(f"Resuming split '{split_name}' TTS from sample index {start_index}")
241
+ else:
242
+ logging.info(f"Progress file '{progress_file}' is empty, starting TTS from index 0.")
243
+ start_index = 0
244
+ except ValueError:
245
+ logging.warning(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.")
246
+ start_index = 0
247
+ except Exception as e:
248
+ logging.error(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
249
+ start_index = 0
250
+ else:
251
+ logging.info(f"No progress file found at '{progress_file}'. Starting TTS from index 0.")
252
+
253
+
254
+ # --- [NEW] Section: Check and Save Already Completed Samples ---
255
+ already_processed_samples = []
256
+ if start_index > 0:
257
+ logging.info(f"Checking for already processed samples (audio files) up to index {start_index - 1}...")
258
+ # Use tqdm here for visibility if start_index is large
259
+ for j in tqdm(range(start_index), desc="Checking existing audio"):
260
+ potential_output_wav_path = os.path.join(split_output_dir, f"{j}.wav")
261
+ if os.path.exists(potential_output_wav_path):
262
+ try:
263
+ # Ensure index is valid before accessing
264
+ if j < len(split_dataset):
265
+ original_sample = split_dataset[j]
266
+ # Create a dict with all original keys + the existing audio path
267
+ completed_sample_dict = {k: original_sample[k] for k in original_sample.keys()}
268
+ completed_sample_dict["audio_filepath"] = potential_output_wav_path # Point to the existing audio
269
+ already_processed_samples.append(completed_sample_dict)
270
+ else:
271
+ logging.warning(f"Index {j} is out of bounds for the loaded dataset (size {len(split_dataset)}) while checking existing files. Skipping this index.")
272
+ except Exception as e:
273
+ logging.error(f"Error processing data for existing sample index {j}: {e}")
274
+
275
+ if already_processed_samples:
276
+ logging.info(f"Found {len(already_processed_samples)} samples with existing audio files before the resume point.")
277
+ # Define path for the dataset of already processed samples
278
+ already_processed_dataset_path = os.path.join(split_output_dir, "already_processed_dataset")
279
+ try:
280
+ logging.info(f"Saving these {len(already_processed_samples)} already processed samples to: {already_processed_dataset_path}")
281
+
282
+ # Define features based on the original dataset + the new audio_filepath column
283
+ original_features = split_dataset.features
284
+ new_features_dict = original_features.copy()
285
+ if "audio_filepath" not in new_features_dict:
286
+ new_features_dict["audio_filepath"] = Value('string')
287
+ new_features = Features(new_features_dict)
288
+
289
+ already_processed_dataset = Dataset.from_list(already_processed_samples, features=new_features)
290
+ already_processed_dataset.save_to_disk(already_processed_dataset_path)
291
+ logging.info(f"Successfully saved dataset of {len(already_processed_samples)} already processed samples.")
292
+ # Clear the list to free memory
293
+ del already_processed_samples
294
+ if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache after potential large list processing
295
+ except Exception as e:
296
+ logging.error(f"Failed to create or save the dataset of already processed samples: {e}", exc_info=True)
297
+ # Keep already_processed_samples list in memory in case of save failure? Maybe not needed.
298
+ else:
299
+ logging.info("No existing audio files found for samples before the resume point (index 0 to {}).".format(start_index - 1))
300
+ # --- [END NEW] Section ---
301
+
302
+
303
+ # --- Main processing loop ---
304
+ final_samples = [] # List to hold ALL processed sample dictionaries for the FINAL dataset of this run
305
+ logging.info(f"Starting/Resuming TTS processing from index {start_index}...")
306
+ pbar = tqdm(range(start_index, len(split_dataset)), desc=f"TTS on '{split_name}' query_rephrased", initial=start_index, total=len(split_dataset))
307
+ for i in pbar:
308
+ try:
309
+ sample = split_dataset[i] # Sample data is on CPU
310
+ except IndexError:
311
+ logging.error(f"Index {i} is out of bounds for split_dataset (size {len(split_dataset)}). Stopping processing.")
312
+ break # Stop if we somehow go out of bounds
313
+
314
+ # Define path for the *new* audio file (or potentially existing one)
315
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
316
+
317
+ # Check if this specific TTS output file ALREADY exists
318
+ # This handles cases where the script stopped AFTER saving audio but BEFORE updating progress
319
+ # OR if files were somehow generated but progress file was lost/reset.
320
+ if os.path.exists(output_wav_path):
321
+ logging.debug(f"Audio file already exists for index {i} at {output_wav_path}. Skipping TTS, adding to final list.")
322
+ # Create the dict with all original keys + the existing audio path
323
+ sample_dict = {k: sample[k] for k in sample.keys()}
324
+ sample_dict["audio_filepath"] = output_wav_path # Point to the existing audio
325
+ final_samples.append(sample_dict) # Add to the list for the *final* dataset
326
+ # Update progress even if skipped due to existing file (important!)
327
+ with open(progress_file, "w") as f:
328
+ f.write(str(i + 1))
329
+ continue # Move to the next sample
330
+
331
+ # --- Perform TTS on the 'query_rephrased' field (UNCHANGED CORE LOGIC) ---
332
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
333
+
334
+ if result is not None:
335
+ audio_tensor = result['audio_tensor'] # Received tensor is on GPU
336
+ sample_rate_val = result['sample_rate']
337
+
338
+ try:
339
+ # --- Move tensor to CPU before saving ---
340
+ audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
341
+ if audio_tensor_save.dim() == 1:
342
+ audio_tensor_save = audio_tensor_save.unsqueeze(0)
343
+ elif audio_tensor_save.dim() > 2:
344
+ audio_tensor_save = audio_tensor_save.view(1, -1)
345
+
346
+ torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
347
+
348
+ # --- Preserve all original fields + add the NEW audio path ---
349
+ sample_dict = {k: sample[k] for k in sample.keys()}
350
+ sample_dict["audio_filepath"] = output_wav_path # Add the path to the new audio
351
+ final_samples.append(sample_dict) # Add to the list for the *final* dataset
352
+
353
+ # Explicitly delete GPU tensor and clear cache
354
+ del audio_tensor
355
+ del audio_tensor_save
356
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
357
+
358
+
359
+ except Exception as e:
360
+ logging.error(f"Failed to save wav for sample {i} (TTS of query_rephrased) at {output_wav_path}: {e}", exc_info=True)
361
+ if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache on save error too
362
+ else:
363
+ # process_example already logged the failure
364
+ logging.warning(f"Sample {i} TTS failed after retries (Rephrased Query: '{str(sample.get('query_rephrased', 'N/A'))[:60]}...'), no audio generated. This sample will NOT be included in the final dataset.")
365
+ # Decide whether to add sample without audio path or skip it
366
+ # Skipping for now, as audio is the goal. If you wanted to include failed ones:
367
+ # sample_dict = {k: sample[k] for k in sample.keys()}
368
+ # sample_dict["audio_filepath"] = None # Indicate missing audio
369
+ # final_samples.append(sample_dict)
370
+
371
+
372
+ # Update progress file after processing each sample (success or failure to ensure resume point advances)
373
+ # Make sure this is outside the 'if result is not None' block
374
+ with open(progress_file, "w") as f:
375
+ f.write(str(i + 1))
376
+
377
+ # Optional: More frequent cache clearing (Uncomment if needed)
378
+ # if i % 50 == 0 and torch.cuda.is_available():
379
+ # torch.cuda.empty_cache()
380
+ # logging.debug(f"Cleared CUDA cache at index {i}")
381
+
382
+ # --- Final cache clear after finishing the split ---
383
+ if torch.cuda.is_available():
384
+ logging.info("Clearing final CUDA cache for the split.")
385
+ torch.cuda.empty_cache()
386
+
387
+ # --- Save the final dataset object for this split (contains items found + newly generated) ---
388
+ if final_samples:
389
+ # Define features based on the original dataset + the new audio_filepath column for the final dataset
390
+ final_features_dict = split_dataset.features.copy()
391
+ if "audio_filepath" not in final_features_dict:
392
+ final_features_dict["audio_filepath"] = Value('string')
393
+ final_features = Features(final_features_dict)
394
+
395
+ try:
396
+ logging.info(f"Attempting to create final dataset object from {len(final_samples)} collected samples...")
397
+ final_dataset_obj = Dataset.from_list(final_samples, features=final_features)
398
+
399
+ # Save the final dataset object inside the split's output directory
400
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") # Name for the complete dataset
401
+
402
+ logging.info(f"Saving final dataset object for split '{split_name}' (with audio paths) to {final_dataset_save_path}...")
403
+ os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True) # Should exist, but safety check
404
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
405
+ logging.info(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples in the final dataset at '{final_dataset_save_path}'.")
406
+ final_dataset_dict_for_tracking[split_name] = final_dataset_obj # Keep track if needed
407
+
408
+ except Exception as e:
409
+ logging.error(f"Error creating or saving final dataset object for split '{split_name}': {e}", exc_info=True)
410
+ logging.error("Attempting to save final_samples list as JSON Lines as a fallback...")
411
+ fallback_path = os.path.join(split_output_dir, "final_samples_fallback.jsonl") # Use distinct fallback name
412
+ try:
413
+ with open(fallback_path, 'w', encoding='utf-8') as f:
414
+ for item in final_samples:
415
+ # Basic serialization attempt
416
+ serializable_item = {}
417
+ for k, v in item.items():
418
+ if isinstance(v, torch.Tensor):
419
+ serializable_item[k] = f"Tensor data (shape: {v.shape})" # Placeholder
420
+ elif isinstance(v, (dict, list, str, int, float, bool, type(None))):
421
+ serializable_item[k] = v
422
+ else:
423
+ serializable_item[k] = str(v) # Attempt string conversion for others
424
+ f.write(json.dumps(serializable_item) + '\n')
425
+ logging.info(f"Fallback JSON Lines saved to {fallback_path}")
426
+ except Exception as json_e:
427
+ logging.error(f"Fallback JSON save failed: {json_e}", exc_info=True)
428
+
429
+ else:
430
+ logging.warning(f"Finished processing split: {split_name}. No samples were successfully processed or found with existing audio during this run to add to the final dataset.")
431
+
432
+
433
+ print("="*30)
434
+ if final_dataset_dict_for_tracking:
435
+ logging.info(f"All specified splits processed for TTS. Final datasets saved in respective 'final_dataset' subdirectories within '{TTS_OUTPUT_PATH}'.")
436
+ logging.info(f"Processed splits where final datasets were generated: {list(final_dataset_dict_for_tracking.keys())}")
437
+ logging.info("Additionally, if resuming, datasets containing only the samples processed *before* this run may have been saved in 'already_processed_dataset' subdirectories.")
438
+ else:
439
+ logging.warning(f"TTS processing finished, but no final datasets were generated or saved in the 'final_dataset' subdirectories within '{TTS_OUTPUT_PATH}'. Check logs for errors. Pre-existing data might be in 'already_processed_dataset' if resuming.")
440
+ print("="*30)
r1-a/dataset/examqa_rewrite.py ADDED
@@ -0,0 +1,487 @@
1
+ import os
2
+ import http.client
3
+ import json
4
+ import time
5
+ import random
6
+ from datasets import load_dataset, Dataset, DatasetDict, Features, Value
7
+ from tqdm.auto import tqdm
8
+ import sys
9
+ import logging
10
+ import getpass
11
+ import signal
12
+ import socket
13
+ import concurrent.futures
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ import argparse # For command-line arguments
16
+
17
+ # --- Configuration --- (Mostly Same)
18
+ DATASET_NAME = "virtuoussy/Multi-subject-RLVR"
19
+ DATASET_SPLIT = "train"
20
+ API_HOST = "api2.aigcbest.top"
21
+ API_PATH = "/v1/chat/completions"
22
+ LLM_MODEL = "gpt-4.1-mini"
23
+ API_KEY = os.environ.get('AIGCBEST_API_KEY', "sk-U15cDXxI0bboL6iH4Hymzl30ws6oWzazWe1Ndwq9QtiPUEgI") # Simplified API Key Get
24
+ if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
25
+ print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.")
26
+ sys.exit(1)
27
+
28
+ OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
29
+ # Define the path where the *potentially incomplete* processed dataset exists
30
+ PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
31
+ # Define where the *final, retried* dataset will be saved
32
+ FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final") # Save to new location initially
33
+
34
+ BATCH_SAVE_SIZE = 500 # How often to save intermediate progress *during retry*
35
+ MAX_WORKERS = 20
36
+ REQUEST_DELAY_SECONDS = 0.15
37
+ MAX_RETRIES = 3
38
+
39
+ # Setup logging
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
41
+ logging.getLogger("datasets").setLevel(logging.WARNING)
42
+ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
43
+
44
+ # --- LLM API Function (call_llm_api) ---
45
+ # Use the robust version from the previous answer
46
+ def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
47
+ system_prompt = (
48
+ "You are an expert linguist specializing in converting structured prompts or "
49
+ "fill-in-the-blank problems into natural, spoken-language questions suitable for "
50
+ "text-to-speech (TTS). Your goal is to make the question sound like how a person "
51
+ "would naturally ask it. "
52
+ "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
53
+ "rephrase it as a direct question asking for the missing information. "
54
+ "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
55
+ "Focus only on rephrasing the *user's question* part provided. "
56
+ "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
57
+ )
58
+ payload = json.dumps({
59
+ "model": model,
60
+ "messages": [
61
+ {"role": "system", "content": system_prompt},
62
+ {"role": "user", "content": original_question}
63
+ ],
64
+ })
65
+ headers = {
66
+ 'Accept': 'application/json',
67
+ 'Authorization': f'Bearer {api_key}',
68
+ 'User-Agent': 'HuggingFace Dataset Processing Script (Retry Mode)',
69
+ 'Content-Type': 'application/json'
70
+ }
71
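+ # Jittered per-request delay so concurrent worker threads do not hit the API in lockstep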
+ time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))
72
+
73
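+ # Retry loop: HTTP 429 honors the Retry-After header when present; 5xx and network errors back off exponentially with jitter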
+ for attempt in range(retries):
74
+ logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...")
75
+ try:
76
+ conn = http.client.HTTPSConnection(host, timeout=60)
77
+ conn.request("POST", path, payload, headers)
78
+ res = conn.getresponse()
79
+ status = res.status
80
+ data = res.read()
81
+ conn.close()
82
+
83
+ if status == 200:
84
+ response_json = json.loads(data.decode("utf-8"))
85
+ if response_json.get("choices") and len(response_json["choices"]) > 0:
86
+ message = response_json["choices"][0].get("message")
87
+ if message and message.get("content"):
88
+ rephrased = message["content"].strip()
89
+ if len(rephrased) > 1 and ((rephrased.startswith('"') and rephrased.endswith('"')) or \
90
+ (rephrased.startswith("'") and rephrased.endswith("'"))):
91
+ rephrased = rephrased[1:-1]
92
+ if rephrased and rephrased.strip().lower() != original_question.strip().lower():
93
+ logging.debug(f"Successfully rephrased: {rephrased[:80]}...")
94
+ return rephrased
95
+ elif not rephrased:
96
+ logging.warning(f"LLM returned empty response for: {original_question[:50]}...")
97
+ return None
98
+ else:
99
+ logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
100
+ return None # Treat identical as failure for rephrasing
101
+ logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
102
+ return None
103
+ elif status == 429:
104
+ retry_after_header = res.getheader('Retry-After', '5')
105
+ try: wait_time = int(retry_after_header)
106
+ except ValueError: wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
107
+ logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
108
+ time.sleep(wait_time)
109
+ elif status >= 500:
110
+ wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
111
+ logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
112
+ time.sleep(wait_time)
113
+ else:
114
+ logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
115
+ return None
116
+ except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
117
+ logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
118
+ if attempt + 1 == retries: return None
119
+ wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
120
+ logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
121
+ time.sleep(wait_time)
122
+ except json.JSONDecodeError as e:
123
+ logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200]}")
124
+ if attempt + 1 == retries: return None
125
+ wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
126
+ time.sleep(wait_time) # Wait before next attempt
127
+ except Exception as e:
128
+ logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
129
+ if attempt + 1 == retries: return None
130
+ wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
131
+ logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
132
+ time.sleep(wait_time)
133
+
134
+ logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
135
+ return None
136
+
137
+
138
+ # --- Dataset Processing Function (rephrase_query_entry) ---
139
+ # Same as before, returns the full dictionary with status
140
+ def rephrase_query_entry(example):
141
+ processed_example = example.copy()
142
+ # Ensure status field exists, default to unprocessed if missing
143
+ if 'query_rephrased_status' not in processed_example:
144
+ processed_example['query_rephrased_status'] = 'unprocessed'
145
+
146
+ original_query_list = example.get("query")
147
+
148
+ # --- Input Validation ---
149
+ if original_query_list is None:
150
+ processed_example['query_rephrased_status'] = 'skipped_missing_query_column'
151
+ processed_example['query_rephrased'] = None
152
+ return processed_example
153
+ if not isinstance(original_query_list, list):
154
+ processed_example['query_rephrased_status'] = 'skipped_query_not_list'
155
+ processed_example['query_rephrased'] = None
156
+ return processed_example
157
+ if not original_query_list:
158
+ processed_example['query_rephrased_status'] = 'skipped_query_list_empty'
159
+ processed_example['query_rephrased'] = None
160
+ return processed_example
161
+
162
+ # --- Find User Question ---
163
+ user_question = None
164
+ for i, message in enumerate(original_query_list):
165
+ if isinstance(message, dict) and message.get("role") == "user":
166
+ content = message.get("content")
167
+ if isinstance(content, str) and content.strip():
168
+ user_question = content
169
+ break
170
+ else:
171
+ processed_example['query_rephrased_status'] = 'skipped_invalid_user_content'
172
+ processed_example['query_rephrased'] = None
173
+ return processed_example
174
+
175
+ if not user_question:
176
+ processed_example['query_rephrased_status'] = 'skipped_no_user_content_found'
177
+ processed_example['query_rephrased'] = None
178
+ return processed_example
179
+
180
+ # --- Call LLM API ---
181
+ logging.info(f"Attempting to rephrase: {user_question[:60]}...") # Log retry attempt
182
+ rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)
183
+
184
+ # --- Update Example Based on API Result ---
185
+ if rephrased_query_content:
186
+ logging.debug(f"Rephrased '{user_question[:30]}...' to '{rephrased_query_content[:30]}...'")
187
+ processed_example["query_rephrased"] = rephrased_query_content
188
+ processed_example['query_rephrased_status'] = 'success_retried' # New status for successful retry
189
+ else:
190
+ logging.warning(f"Retry failed for user question: {user_question[:50]}...")
191
+ # Keep existing rephrased content (likely None) but update status
192
+ processed_example['query_rephrased_status'] = 'failed_llm_retry' # New status for failed retry
193
+
194
+ return processed_example
195
+
196
+
197
+ # --- Function to Save Progress ---
198
+ # Saves the *entire list* of dictionaries
199
+ def save_final_dataset(data_list, output_path):
200
+ """Saves the final list of processed data dictionaries."""
201
+ if not data_list:
202
+ logging.info("No data provided for saving.")
203
+ return False
204
+ logging.info(f"Attempting to save {len(data_list)} final examples to {output_path}...")
205
+ try:
206
+ # Define features explicitly to handle potential Nones and ensure consistency
207
+ # Adjust types based on your actual dataset structure
208
+ features = Features({
209
+ 'query': [{'role': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None)}],
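+ # 'query' is a list of chat messages, each a dict with string 'role' and 'content'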
210
+ 'query_rephrased': Value(dtype='string', id=None), # Allow nulls
211
+ 'query_rephrased_status': Value(dtype='string', id=None), # Allow nulls
212
+ # Add other columns from your original dataset here...
213
+ # Example: 'answer': Value(dtype='string', id=None),
214
+ # Example: 'subject': Value(dtype='string', id=None),
215
+ # IMPORTANT: List all columns present in your loaded dataset
216
+ 'query_code': Value(dtype='string', id=None),
217
+ 'answer': Value(dtype='string', id=None),
218
+ 'answer_code': Value(dtype='string', id=None),
219
+ 'subject': Value(dtype='string', id=None),
220
+ 'grade': Value(dtype='string', id=None),
221
+ 'source': Value(dtype='string', id=None),
222
+ 'split': Value(dtype='string', id=None),
223
+ '__index_level_0__': Value(dtype='int64', id=None) # Check if this column exists
224
+ })
225
+
226
+ # Clean data slightly - replace python None with "" for string fields if needed by Arrow
227
+ # or ensure feature definition handles nulls correctly (Value(dtype='string', id=None) should)
228
+ # cleaned_data_list = []
229
+ # for item in data_list:
230
+ # cleaned_item = item.copy()
231
+ # for key, feature_type in features.items():
232
+ # if isinstance(feature_type, Value) and feature_type.dtype == 'string':
233
+ # if cleaned_item.get(key) is None:
234
+ # cleaned_item[key] = "" # Or keep None if schema allows
235
+ # cleaned_data_list.append(cleaned_item)
236
+
237
+
238
+ # Use the original list directly if schema handles None
239
+ processed_dataset = Dataset.from_list(list(data_list), features=features)
240
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
241
+ processed_dataset.save_to_disk(output_path)
242
+ logging.info(f"Successfully saved final dataset ({len(data_list)} items) to {output_path}")
243
+ return True
244
+ except Exception as e:
245
+ logging.error(f"Failed to save final dataset to {output_path}: {e}", exc_info=True)
246
+ # Try saving as JSON as a fallback
247
+ fallback_json_path = output_path + ".jsonl"
248
+ logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
249
+ try:
250
+ with open(fallback_json_path, 'w', encoding='utf-8') as f:
251
+ for item in data_list:
252
+ f.write(json.dumps(item, ensure_ascii=False) + '\n')
253
+ logging.info(f"Successfully saved fallback JSON Lines file to {fallback_json_path}")
254
+ except Exception as json_e:
255
+ logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
256
+ return False
257
+
258
+
259
+ # --- Helper to get original user query ---
260
+ def get_user_query(example):
261
+ """Extracts the user query content from the 'query' list."""
262
+ query_list = example.get("query")
263
+ if isinstance(query_list, list):
264
+ for message in query_list:
265
+ if isinstance(message, dict) and message.get("role") == "user":
266
+ content = message.get("content")
267
+ if isinstance(content, str) and content.strip():
268
+ return content
269
+ return None
270
+
271
+
272
+ # --- Function to Check if Retry is Needed ---
273
+ def needs_retry(example):
274
+ """Determines if an example needs reprocessing based on its current state."""
275
+ status = example.get('query_rephrased_status')
276
+ rephrased_text = example.get('query_rephrased')
277
+
278
+ # Condition 1: Explicit failure status from previous (new script) run
279
+ if status in ['failed_llm_call', 'failed_llm_retry', 'failed_processing_exception']:
280
+ return True
281
+
282
+ # Condition 2: Certain 'skipped' statuses might warrant a retry (optional, adjust as needed)
283
+ # For example, if the user content was invalid originally, retrying won't help.
284
+ # if status in ['skipped_no_user_content_found']: # Decide if these should be retried
285
+ # return True
286
+
287
+ # Condition 3: Status indicates success OR status is missing/old,
288
+ # BUT the rephrased text is missing or empty. This catches failures
289
+ # from the *old* script or inconsistent states.
290
+ if rephrased_text is None or not str(rephrased_text).strip():
291
+ # Don't retry if it was intentionally skipped due to bad input
292
+ if status not in ['skipped_missing_query_column', 'skipped_query_not_list',
293
+ 'skipped_query_list_empty', 'skipped_invalid_user_content',
294
+ 'skipped_no_user_content_found']:
295
+ return True
296
+
297
+ # Condition 4 (Optional but recommended): Check if rephrased text is identical to original user query
298
+ # This requires extracting the original query here.
299
+ # original_user_query = get_user_query(example)
300
+ # if original_user_query and isinstance(rephrased_text, str) and \
301
+ # rephrased_text.strip().lower() == original_user_query.strip().lower():
302
+ # # Check status first - if it was intentionally skipped, don't retry
303
+ # if status not in ['skipped_missing_query_column', 'skipped_query_not_list',
304
+ # 'skipped_query_list_empty', 'skipped_invalid_user_content',
305
+ # 'skipped_no_user_content_found']:
306
+ # logging.debug(f"Identified identical query/rephrased text for retry: {original_user_query[:50]}...")
307
+ # return True
308
+
309
+
310
+ # Default: No retry needed
311
+ return False
312
+
313
+
314
+ # --- Main Execution ---
315
+ if __name__ == "__main__":
316
+ start_time = time.time()
317
+ logging.info("======================================================")
318
+ logging.info(f" Starting Dataset Processing Script in RETRY MODE")
319
+ logging.info("======================================================")
320
+ logging.info(f"Dataset: {DATASET_NAME}, Split: {DATASET_SPLIT}")
321
+ logging.info(f"Loading existing processed data from: {PROCESSED_DATA_PATH}")
322
+ logging.info(f"Final output will be saved to: {FINAL_OUTPUT_PATH}")
323
+ logging.info(f"Max concurrent workers: {MAX_WORKERS}")
324
+
325
+ # --- Load Existing Processed Dataset ---
326
+ if not os.path.exists(PROCESSED_DATA_PATH):
327
+ logging.error(f"Existing processed data not found at '{PROCESSED_DATA_PATH}'. Cannot run in retry mode.")
328
+ sys.exit(1)
329
+
330
+ logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
331
+ try:
332
+ # Load the dataset saved by the previous script run
333
+ existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
334
+ # Convert to list of dictionaries for easier modification access by index
335
+ # Be mindful of memory usage for very large datasets
336
+ results_list = existing_dataset.to_list()
337
+ total_examples = len(results_list)
338
+ logging.info(f"Loaded {total_examples} examples.")
339
+ # Ensure essential columns exist, add them if missing from old format
340
+ for i in range(total_examples):
341
+ if 'query_rephrased' not in results_list[i]:
342
+ results_list[i]['query_rephrased'] = None
343
+ if 'query_rephrased_status' not in results_list[i]:
344
+ results_list[i]['query_rephrased_status'] = 'unknown_original_status'
345
+
346
+ except Exception as e:
347
+ logging.error(f"Failed to load existing dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
348
+ sys.exit(1)
349
+
350
+
351
+ # --- Identify Indices to Retry ---
352
+ indices_to_retry = [
353
+ i for i, example in enumerate(results_list) if needs_retry(example)
354
+ ]
355
+ num_to_retry = len(indices_to_retry)
356
+
357
+ if num_to_retry == 0:
358
+ logging.info("No examples found needing retry based on the criteria.")
359
+ logging.info(f"The dataset at {PROCESSED_DATA_PATH} is considered final.")
360
+ # Optional: You might still want to save it to FINAL_OUTPUT_PATH for consistency
361
+ # if not os.path.exists(FINAL_OUTPUT_PATH):
362
+ # save_final_dataset(results_list, FINAL_OUTPUT_PATH)
363
+ sys.exit(0)
364
+
365
+ logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")
366
+
367
+ # --- Prepare for Concurrent Retries ---
368
+ processed_count_in_retry = 0
369
+ # We don't need batch saving in the same way, but can update the list in memory
370
+ # A temporary dictionary to store results from futures before updating the main list
371
+ retry_results_dict = {}
372
+
373
+ logging.info("Starting concurrent processing for examples needing retry...")
374
+
375
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
376
+ # Submit jobs only for the indices needing retry
377
+ # Pass the *specific example dictionary* to the function
378
+ futures = {
379
+ executor.submit(rephrase_query_entry, results_list[i]): i
380
+ for i in indices_to_retry
381
+ }
382
+
383
+ try:
384
+ pbar = tqdm(total=num_to_retry, desc="Retrying failed examples", unit="example")
385
+ for future in concurrent.futures.as_completed(futures):
386
+ original_index = futures[future] # Get the index in the full results_list
387
+ try:
388
+ # Get the updated dictionary result from the retry attempt
389
+ updated_example_dict = future.result()
390
+ # Store the result temporarily, keyed by original index
391
+ retry_results_dict[original_index] = updated_example_dict
392
+ pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
393
+
394
+ except Exception as exc:
395
+ # Catch errors *during* the retry processing itself
396
+ logging.error(f'Retry for example index {original_index} generated an exception: {exc}', exc_info=True)
397
+ # Create a placeholder indicating the retry attempt failed due to an exception
398
+ error_placeholder = results_list[original_index].copy() # Get original data again
399
+ error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
400
+ # Store this error placeholder
401
+ retry_results_dict[original_index] = error_placeholder
402
+ pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
403
+
404
+ finally:
405
+ processed_count_in_retry += 1
406
+ pbar.update(1)
407
+ # Optional intermediate save logic (maybe save every N retries)
408
+ # Could save the *entire* potentially partially updated list, but might be slow.
409
+ # if processed_count_in_retry % BATCH_SAVE_SIZE == 0:
410
+ # logging.info(f"Processed {processed_count_in_retry} retries, updating intermediate state...")
411
+ # # Update the main list with results gathered so far
412
+ # for idx, updated_item in retry_results_dict.items():
413
+ # results_list[idx] = updated_item
414
+ # # Clear the temporary dict after updating
415
+ # retry_results_dict.clear()
416
+ # # Save the whole list (potentially slow)
417
+ # save_final_dataset(results_list, FINAL_OUTPUT_PATH + "_interim")
418
+
419
+
420
+ except KeyboardInterrupt:
421
+ logging.warning("\nCtrl+C detected during retry! Attempting to save progress...")
422
+
423
+ except Exception as e:
424
+ logging.error(f"An unexpected error occurred during the retry loop: {e}", exc_info=True)
425
+
426
+ finally:
427
+ if 'pbar' in locals():
428
+ pbar.close()
429
+
430
+ # --- Update the main results list with all completed retries ---
431
+ logging.info("Updating main results list with completed retry attempts...")
432
+ update_count = 0
433
+ for idx, updated_item in retry_results_dict.items():
434
+ if idx < len(results_list):
435
+ results_list[idx] = updated_item
436
+ update_count += 1
437
+ else:
438
+ logging.error(f"Index {idx} from retry results is out of bounds for results_list (size {len(results_list)}). Skipping update.")
439
+
440
+ logging.info(f"Applied updates for {update_count} retried items.")
441
+
442
+ # --- Final Save ---
443
+ logging.info(f"Attempting to save the final updated dataset to: {FINAL_OUTPUT_PATH}")
444
+ if save_final_dataset(results_list, FINAL_OUTPUT_PATH):
445
+ logging.info("Final dataset saved successfully.")
446
+ # Optional: Suggest deleting the old intermediate path if successful
447
+ # logging.info(f"You may now safely remove the intermediate directory: {PROCESSED_DATA_PATH}")
448
+ else:
449
+ logging.error(">>> FINAL SAVE FAILED! <<<")
450
+ logging.error(f"Check the logs. The latest state might be in memory or a fallback JSON file if created.")
451
+
452
+
453
+ # --- Final Verification (Optional) ---
454
+ logging.info("------------------------------------------------------")
455
+ logging.info("Verification: Loading final saved dataset for status check...")
456
+ try:
457
+ final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
458
+ logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
459
+ status_counts = {}
460
+ for ex in final_reloaded_dataset:
461
+ status = ex.get('query_rephrased_status', 'unknown_status_field')
462
+ status_counts[status] = status_counts.get(status, 0) + 1
463
+
464
+ logging.info("Status counts in the final saved file:")
465
+ for status, count in sorted(status_counts.items()):
466
+ logging.info(f" - {status}: {count}")
467
+
468
+ # Highlight remaining failures
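+ # Retry-exception statuses are recorded as 'failed_retry_exception_<ExceptionName>', so they are matched by prefix below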
469
+ remaining_failures = status_counts.get('failed_llm_retry', 0) + \
470
+ sum(c for s, c in status_counts.items() if s.startswith('failed_retry_exception')) + \
471
+ status_counts.get('failed_llm_call', 0) # Include original failures if not retried/still failing
472
+
473
+ if remaining_failures > 0:
474
+ logging.warning(f"Found {remaining_failures} examples still marked as failed after retry attempts.")
475
+ else:
476
+ logging.info("All identified failures appear to have been successfully retried or were not retried.")
477
+
478
+ except FileNotFoundError:
479
+ logging.error(f"Verification failed: Final saved dataset not found at {FINAL_OUTPUT_PATH}.")
480
+ except Exception as e:
481
+ logging.error(f"Failed to reload or verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)
482
+
483
+
484
+ end_time = time.time()
485
+ logging.info("------------------------------------------------------")
486
+ logging.info(f"Retry script finished in {end_time - start_time:.2f} seconds.")
487
+ logging.info("======================================================")
r1-a/dataset/final_tts.py ADDED
@@ -0,0 +1,316 @@
1
+ # --- ENVIRONMENT VARIABLE CONTROL ---
2
+ import os
3
+ import sys
4
+ import random
5
+ import torch
6
+ import torchaudio
7
+ from datasets import load_dataset, Dataset, load_from_disk, Features, Value, Audio
8
+ from tqdm import tqdm
9
+ import time
10
+ import logging
11
+ import json
12
+ import math
13
+ import pathlib
14
+ import re
15
+ import unicodedata
16
+
17
+ # --- Read Environment Variables ---
18
+ try:
19
+ # GPU ID for this specific run (0, 1, 2, or 3)
20
+ PROCESSING_GPU_ID = int(os.environ.get("PROCESSING_GPU_ID", "3"))
21
+ # Total number of parallel runs (should be 4)
22
+ TOTAL_PROCESSING_NODES = int(os.environ.get("TOTAL_PROCESSING_NODES", "4"))
23
+ except ValueError:
24
+ print("Error: PROCESSING_GPU_ID and TOTAL_PROCESSING_NODES env vars must be integers.")
25
+ sys.exit(1)
26
+
27
+ if not 0 <= PROCESSING_GPU_ID < TOTAL_PROCESSING_NODES:
28
+ print(f"Error: PROCESSING_GPU_ID ({PROCESSING_GPU_ID}) must be between 0 and {TOTAL_PROCESSING_NODES - 1}.")
29
+ sys.exit(1)
30
+
31
+ print(f"--- Starting Run for Shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} ---")
32
+ print(f"--- Targetting GPU Index (physical): {PROCESSING_GPU_ID} ---")
33
+
34
+ # --- SET VISIBLE CUDA DEVICE *BEFORE* ANY CUDA CONTEXT IS INITIALIZED ---
35
+ # This makes the chosen GPU appear as 'cuda:0' to this script instance
36
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(PROCESSING_GPU_ID)
37
+
38
+ # Check CUDA availability *after* setting visibility
39
+ if not torch.cuda.is_available():
40
+ print(f"ERROR: CUDA device {PROCESSING_GPU_ID} is not available after setting CUDA_VISIBLE_DEVICES.")
41
+ sys.exit(1)
42
+ else:
43
+ # PyTorch now sees the selected GPU as cuda:0
44
+ effective_device = torch.device("cuda:0")
45
+ try:
46
+ print(f"Script process {os.getpid()} successfully assigned to specific GPU: {torch.cuda.get_device_name(0)} (Original Index {PROCESSING_GPU_ID})")
47
+ except Exception as e:
48
+ print(f"Warning: Could not get device name, but CUDA is available. Error: {e}")
49
+ print(f"Script process {os.getpid()} assigned to specific GPU index {PROCESSING_GPU_ID}")
50
+
51
+
52
+ # --- Add CosyVoice Path ---
53
+ COSYVOICE_PATH = '/home/chenyifu/CosyVoice' # <-- Your path
54
+ if COSYVOICE_PATH not in sys.path:
55
+ sys.path.append(COSYVOICE_PATH)
56
+
57
+ # Import CosyVoice
58
+ try:
59
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
60
+ from cosyvoice.utils.file_utils import load_wav
61
+ except ImportError as e:
62
+ print(f"Error importing CosyVoice: {e}")
63
+ sys.exit(1)
64
+
65
+ # Setup basic logging for this instance
66
+ logging.basicConfig(level=logging.INFO, format=f'%(asctime)s - %(levelname)s - [GPU-{PROCESSING_GPU_ID}] %(message)s')
67
+
68
+ # ------------------------
69
+ # Configuration Parameters - Mostly unchanged
70
+ # ------------------------
71
+ # --- Input Dataset ---
72
+ INPUT_DATASET_PATH = '/home/chenyifu/audio-r1/r1-a/dataset/prompt_only'
73
+ TEXT_FIELD_FOR_TTS = "question_text"
74
+ AUDIO_PATH_FIELD = "question_audio" # Field name for the *final* aggregated dataset
75
+ ASSUMED_INPUT_SPLIT_NAME = "train"
76
+
77
+ # --- Output ---
78
+ TTS_OUTPUT_BASE_PATH = '/home/chenyifu/audio-r1/r1-a/dataset/prompt_only_fully_merged_with_audio' # <<-- SHARED output path for all runs
79
+ AUDIO_SUBFOLDER_NAME = 'audio_files' # <<-- SHARED audio subfolder
80
+
81
+ # --- Prompt Audio Settings ---
82
+ RAW_PROMPT_DATASET_PATH = "/home/chenyifu/audio-r1/r1-a/dataset/mls_eng10k"
83
+ PROMPT_DATASET_SPLIT = "train"
84
+ PROMPT_MIN_DURATION_S = 10
85
+ PROMPT_MAX_DURATION_S = 13
86
+ PROMPT_TEXT_FIELD = "transcript"
87
+ PROMPT_AUDIO_DURATION_FIELD = "audio_duration"
88
+ FILTERED_PROMPT_DATASET_PATH = f"{RAW_PROMPT_DATASET_PATH}_filtered_{PROMPT_MIN_DURATION_S}_{PROMPT_MAX_DURATION_S}s" # SHARED path
89
+
90
+ # --- TTS Settings ---
91
+ TARGET_SAMPLE_RATE = 16000 # Desired *prompt* sample rate
92
+ MAX_TTS_RETRIES = 3
93
+ RETRY_DELAY_SECONDS = 2
94
+
95
+ # --- Processing Settings ---
96
+ # TEST_SINGLE_SAMPLE = False # Not needed, shard logic handles subset
97
+ # MULTI_GPU_PROCESSING = False # Not using mp.spawn
98
+
99
+ # ------------------------
100
+ # Helper Functions - Unchanged from the previous multi-GPU capable script
101
+ # ------------------------
102
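+ # Normalize Unicode, replace dash variants with spaces, drop soft hyphens/zero-width chars, collapse whitespace, and ensure terminal punctuation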
+ def preprocess_text(text):
103
+ if not isinstance(text, str): return ""
104
+ text = unicodedata.normalize('NFKC', text)
105
+ text = re.sub(r'[—–―‐‑⁃﹣-]', ' ', text)
106
+ text = text.replace('\u00AD', '').replace('\u200B', '')
107
+ text = re.sub(r'\s+', ' ', text).strip()
108
+ if text and text[-1] not in ['.', '?', '!']: text += '.'
109
+ return text
110
+
111
+ def filter_prompt_logic(example):
112
+ if PROMPT_AUDIO_DURATION_FIELD in example and isinstance(example[PROMPT_AUDIO_DURATION_FIELD], (int, float)):
113
+ duration = example[PROMPT_AUDIO_DURATION_FIELD]; return PROMPT_MIN_DURATION_S <= duration <= PROMPT_MAX_DURATION_S
114
+ else:
115
+ try:
116
+ audio_info = example['audio']; samplerate = audio_info['sampling_rate']; duration = len(audio_info['array']) / samplerate
117
+ return PROMPT_MIN_DURATION_S <= duration <= PROMPT_MAX_DURATION_S
118
+ except: return False
119
+
120
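+ # Pick a random prompt sample; returns (mono waveform resampled to target_sample_rate, transcript)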
+ def get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate=TARGET_SAMPLE_RATE):
121
+ if not filtered_prompt_dataset or len(filtered_prompt_dataset) == 0: raise ValueError("Filtered prompt dataset empty!")
122
+ idx = random.randint(0, len(filtered_prompt_dataset) - 1)
123
+ try: sample = filtered_prompt_dataset[idx]
124
+ except IndexError: sample = filtered_prompt_dataset[0]
125
+ audio_info = sample['audio']; prompt_text = sample[PROMPT_TEXT_FIELD]
126
+ if isinstance(audio_info, dict) and 'array' in audio_info: waveform = torch.tensor(audio_info['array'], dtype=torch.float32); sr = audio_info['sampling_rate']
127
+ elif isinstance(audio_info, str) or (isinstance(audio_info, dict) and 'path' in audio_info):
128
+ path = audio_info if isinstance(audio_info, str) else audio_info['path']; waveform, sr = torchaudio.load(path)
129
+ else: raise TypeError("Unknown prompt audio format")
130
+ if not prompt_text or waveform.numel() == 0: return get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate)
131
+ if sr != target_sample_rate:
132
+ if waveform.dim() > 1 and waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True)
133
+ elif waveform.dim() == 1: waveform = waveform.unsqueeze(0)
134
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate); waveform = resampler(waveform)
135
+ if waveform.dim()==1: waveform = waveform.unsqueeze(0)
136
+ elif waveform.shape[0] > 1 : waveform = waveform.mean(dim=0, keepdim=True)
137
+ return waveform.cpu(), prompt_text
138
+
139
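+ # Synthesize one text: draw a random voice prompt, run zero-shot TTS with retries and backoff; returns {'audio_tensor', 'sample_rate'} or None on failure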
+ def text_to_audio(text_to_convert, cosyvoice, filtered_prompt_dataset, target_sample_rate, stream=False, max_retries=MAX_TTS_RETRIES):
140
+ cleaned_text = preprocess_text(text_to_convert)
141
+ if not cleaned_text: logging.warning(f"Empty text after cleaning: '{text_to_convert[:60]}...'"); return None
142
+ last_exception = None
143
+ for attempt in range(max_retries):
144
+ try:
145
+ prompt_speech, prompt_text = get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate)
146
+ all_speech = []
147
+ inference_generator = cosyvoice.inference_zero_shot(cleaned_text, prompt_text, prompt_speech, stream=stream)
148
+ for i, chunk in enumerate(inference_generator):
149
+ if 'tts_speech' in chunk and chunk['tts_speech'] is not None: all_speech.append(chunk['tts_speech'])
150
+ if not all_speech: raise ValueError(f"TTS produced no audio chunks. Cleaned: '{cleaned_text[:60]}...'")
151
+ combined_speech = torch.cat(all_speech, dim=-1); actual_sample_rate = cosyvoice.sample_rate
152
+ return {'audio_tensor': combined_speech, 'sample_rate': actual_sample_rate}
153
+ except Exception as e:
154
+ last_exception = e; logging.error(f"TTS Error Attempt {attempt + 1}: {e}", exc_info=False)
155
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
156
+ if attempt < max_retries - 1: time.sleep(RETRY_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(0.5, 1.5))
157
+ else: logging.error(f"All TTS retries failed for: '{cleaned_text[:60]}...'")
158
+ return None
159
+
160
+ # -----------------------------
161
+ # --- Main Execution Logic ----
162
+ # -----------------------------
163
+ if __name__ == "__main__":
164
+ # --- Load or Create Filtered Prompt Dataset (Safe for concurrent runs, only first one creates) ---
165
+ # Add a small delay + check to mitigate potential race condition on creation
166
+ if not os.path.exists(FILTERED_PROMPT_DATASET_PATH):
167
+ time.sleep(random.uniform(0, 2)) # Small random delay
168
+ if not os.path.exists(FILTERED_PROMPT_DATASET_PATH): # Double check
169
+ logging.info(f"Filtered prompt dataset not found. Attempting creation...")
170
+ try:
171
+ prompt_dataset_raw = load_dataset(RAW_PROMPT_DATASET_PATH, split=PROMPT_DATASET_SPLIT)
172
+ if 'audio' in prompt_dataset_raw.features and not isinstance(prompt_dataset_raw.features['audio'], Audio):
173
+ prompt_dataset_raw = prompt_dataset_raw.cast_column("audio", Audio(decode=True))
174
+ filtered_ds = prompt_dataset_raw.filter(filter_prompt_logic, num_proc=max(1, os.cpu_count() // 2))
175
+ if len(filtered_ds) == 0: raise ValueError("No prompts left after filtering.")
176
+ cols = ['audio', PROMPT_TEXT_FIELD]
177
+ if PROMPT_AUDIO_DURATION_FIELD in prompt_dataset_raw.column_names: cols.append(PROMPT_AUDIO_DURATION_FIELD)
178
+ filtered_ds = filtered_ds.select_columns(cols)
179
+ filtered_ds.save_to_disk(FILTERED_PROMPT_DATASET_PATH)
180
+ logging.info(f"Filtered prompt dataset CREATED and saved to: {FILTERED_PROMPT_DATASET_PATH}")
181
+ except Exception as e:
182
+ logging.error(f"FATAL: Failed to create filtered prompt dataset: {e}", exc_info=True)
183
+ # If creation fails, other processes might also fail loading. Check logs.
184
+ sys.exit(1)
185
+ else:
186
+ logging.info(f"Filtered prompt dataset appeared while waiting. Proceeding.")
187
+
188
+ try:
189
+ logging.info(f"Loading filtered prompt dataset from: {FILTERED_PROMPT_DATASET_PATH}")
190
+ filtered_prompt_dataset = load_from_disk(FILTERED_PROMPT_DATASET_PATH)
191
+ logging.info(f"Loaded {len(filtered_prompt_dataset)} filtered prompts.")
192
+ except Exception as e:
193
+ logging.error(f"FATAL: Failed to load filtered prompt dataset: {e}", exc_info=True)
194
+ sys.exit(1)
195
+
196
+
197
+ # --- Initialize TTS Model (on the assigned GPU 'cuda:0') ---
198
+ logging.info("Initializing CosyVoice model for this process...")
199
+ try:
200
+ # Model will initialize on cuda:0, which corresponds to the selected physical GPU
201
+ cosyvoice = CosyVoice2(
202
+ f'{COSYVOICE_PATH}/pretrained_models/CosyVoice2-0.5B',
203
+ load_jit=True, load_trt=False, fp16=False
204
+ )
205
+ model_output_sr = cosyvoice.sample_rate
206
+ logging.info(f"CosyVoice initialized. Model output SR: {model_output_sr}")
207
+ except Exception as e:
208
+ logging.error(f"Error initializing CosyVoice2 model: {e}", exc_info=True)
209
+ sys.exit(1)
210
+
211
+
212
+ # --- Load Main Input Dataset ---
213
+ logging.info(f"Loading main input dataset from: {INPUT_DATASET_PATH}")
214
+ try:
215
+ input_dataset_full = load_from_disk(INPUT_DATASET_PATH)
216
+ dataset_size = len(input_dataset_full)
217
+ logging.info(f"Loaded main dataset with {dataset_size} examples.")
218
+
219
+ # --- Add original index column ---
220
+ logging.info("Adding original index...")
221
+ def add_index(example, idx): example['original_index'] = idx; return example
222
+ input_dataset_with_indices = input_dataset_full.map(add_index, with_indices=True, num_proc=max(1, os.cpu_count() // 2))
223
+ logging.info("Original index added.")
224
+
225
+ # --- Shard the dataset for this specific run ---
226
+ logging.info(f"Selecting shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} for processing...")
227
+ dataset_shard = input_dataset_with_indices.shard(
228
+ num_shards=TOTAL_PROCESSING_NODES,
229
+ index=PROCESSING_GPU_ID,
230
+ contiguous=True # Potentially faster access
231
+ )
232
+ shard_size = len(dataset_shard)
233
+ logging.info(f"This instance will process {shard_size} samples.")
234
+
235
+ except Exception as e:
236
+ logging.error(f"Error loading or sharding main input dataset: {e}", exc_info=True)
237
+ sys.exit(1)
238
+
239
+ # --- Define Shared Output Audio Directory ---
240
+ # All instances write to the same place
241
+ split_output_dir = os.path.join(TTS_OUTPUT_BASE_PATH, ASSUMED_INPUT_SPLIT_NAME)
242
+ split_audio_dir = os.path.join(split_output_dir, AUDIO_SUBFOLDER_NAME)
243
+ os.makedirs(split_audio_dir, exist_ok=True) # Ensure it exists
244
+
245
+ # --- Process the Assigned Shard ---
246
+ logging.info(f"Starting TTS processing for shard {PROCESSING_GPU_ID + 1}...")
247
+ pbar = tqdm(total=shard_size, desc=f"GPU-{PROCESSING_GPU_ID} TTS", ncols=100)
248
+ for sample in dataset_shard:
249
+ try:
250
+ original_idx = sample['original_index']
251
+ text_to_convert = sample.get(TEXT_FIELD_FOR_TTS)
252
+
253
+ if not text_to_convert or not isinstance(text_to_convert, str) or not text_to_convert.strip():
254
+ logging.warning(f"Skipping original index {original_idx}: missing/invalid text.")
255
+ pbar.update(1)
256
+ continue
257
+
258
+ # Define audio path using original index -> SHARED audio dir
259
+ audio_filename = f"query_{original_idx}.wav"
260
+ absolute_audio_path = os.path.join(split_audio_dir, audio_filename)
261
+
262
+ # Check if audio already exists (supports resuming any run)
263
+ if os.path.exists(absolute_audio_path):
264
+ logging.debug(f"Audio exists for original index {original_idx}, skipping.")
265
+ pbar.update(1)
266
+ continue
267
+
268
+ # Perform TTS
269
+ tts_result = text_to_audio(
270
+ text_to_convert,
271
+ cosyvoice,
272
+ filtered_prompt_dataset,
273
+ TARGET_SAMPLE_RATE, # Prompt SR
274
+ stream=False
275
+ )
276
+
277
+ if tts_result is not None:
278
+ audio_tensor = tts_result['audio_tensor'] # GPU tensor
279
+ output_sample_rate = tts_result['sample_rate'] # Use model's actual SR
280
+
281
+ try:
282
+ audio_tensor_cpu = audio_tensor.detach().cpu().to(torch.float32)
283
+ if audio_tensor_cpu.dim() == 1: audio_tensor_cpu = audio_tensor_cpu.unsqueeze(0)
284
+ elif audio_tensor_cpu.dim() > 2: audio_tensor_cpu = audio_tensor_cpu.view(1, -1)
285
+
286
+ # Save to the SHARED audio directory
287
+ torchaudio.save(absolute_audio_path, audio_tensor_cpu, output_sample_rate)
288
+ logging.debug(f"Saved audio for original index {original_idx}")
289
+
290
+ del audio_tensor, audio_tensor_cpu
291
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
292
+
293
+ except Exception as e:
294
+ logging.error(f"Failed to save audio for original index {original_idx}: {e}")
295
+ if 'audio_tensor' in locals(): del audio_tensor
296
+ if 'audio_tensor_cpu' in locals(): del audio_tensor_cpu
297
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
298
+ else:
299
+ logging.warning(f"TTS failed for original index {original_idx} after retries.")
300
+
301
+ except Exception as e:
302
+ logging.error(f"Unexpected error processing sample for original index {sample.get('original_index', 'UNKNOWN')}: {e}", exc_info=True)
303
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
304
+
305
+ finally:
306
+ pbar.update(1)
307
+
308
+ pbar.close()
309
+ logging.info(f"--- Finished processing shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} ---")
310
+
311
+ # --- End of Script ---
312
+ print("="*30)
313
+ print(f"Run for GPU {PROCESSING_GPU_ID} completed.")
314
+ print(f"Audio files (if successful) saved in: {split_audio_dir}")
315
+ print("IMPORTANT: Run the aggregation script AFTER all 4 runs are finished to create the final dataset object.")
316
+ print("="*30)
r1-a/dataset/gsm8k.py ADDED
@@ -0,0 +1,169 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ from datasets import load_dataset, Dataset
6
+ import sys
7
+ from tqdm import tqdm
8
+
9
+ sys.path.append('/root/autodl-tmp/CosyVoice')
10
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
11
+ from cosyvoice.utils.file_utils import load_wav
12
+
13
+ # Configuration parameters
14
+ COMMON_VOICE_LANGUAGE = "en"
15
+ DATASET_NAME = "gsm8k" # 使用 gsm8k 数据集
16
+ OUTPUT_DATASET_PATH = './gsm8k_with_audio'
17
+ SAMPLE_RATE = 16000
18
+
19
+ # --- Helper functions ---
20
+
21
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
22
+ """
23
+ Randomly sample one utterance and its transcript from the prompt dataset (VoxPopuli here) to use as the zero-shot prompt.
24
+ """
25
+ idx = random.randint(0, len(common_voice_dataset) - 1)
26
+ sample = common_voice_dataset.select([idx])[0]
27
+ audio = sample['audio']
28
+ waveform = torch.tensor(audio['array'], dtype=torch.float32)
29
+ sr = audio['sampling_rate']
30
+ if sr != sample_rate:
31
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
32
+ waveform = resampler(waveform)
33
+ return waveform.unsqueeze(0), sample['raw_text']
34
+
35
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
36
+ """
37
+ Convert the input text to speech with the CosyVoice2 model, using a randomly drawn prompt for zero-shot inference.
38
+ """
39
+ try:
40
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
41
+ # Optional: save prompt.wav for debugging
42
+ # torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE)
43
+ all_speech = []
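+ # inference_zero_shot yields chunks; collect each chunk's 'tts_speech' tensor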
44
+ for i, j in enumerate(cosyvoice.inference_zero_shot(
45
+ query_text,
46
+ prompt_text,
47
+ prompt_speech,
48
+ stream=stream,
49
+ text_frontend=False
50
+ )):
51
+ all_speech.append(j['tts_speech'])
52
+ # Concatenate all generated speech chunks into one long tensor
53
+ combined_speech = torch.cat(all_speech, dim=-1)
54
+ sample_rate_val = cosyvoice.sample_rate
55
+ return {'audio_tensor': combined_speech, 'sample_rate': sample_rate_val}
56
+ except Exception as e:
57
+ print(f"Error converting text to audio: {e}")
58
+ return None
59
+
60
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
61
+ """
62
+ Run TTS for a single gsm8k sample.
63
+ Assumes the question text is in the 'question' field
64
+ and the answer is in the 'answer' field.
65
+ """
66
+ query = example['question']
67
+ audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
68
+ if audio_result is not None:
69
+ # Return the generated audio tensor and its sample rate
70
+ return {
71
+ 'audio_tensor': audio_result['audio_tensor'],
72
+ 'sample_rate': audio_result['sample_rate']
73
+ }
74
+ else:
75
+ return None
76
+
77
+ # --- Data loading and model initialization ---
78
+
79
+ print("Loading Common Voice dataset...")
80
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
81
+ print(f"Total Common Voice {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
82
+
83
+ print("Initializing CosyVoice2 model...")
84
+ cosyvoice = CosyVoice2(
85
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # replace with the actual model path
86
+ load_jit=True,
87
+ load_trt=False,
88
+ fp16=False
89
+ )
90
+
91
+ print("Loading GSM8K dataset...")
92
+ dataset = load_dataset("openai/gsm8k", 'main')
93
+
94
+ # Make sure the top-level output directory exists
95
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
96
+
97
+ # --- Main processing loop ---
98
+ # Each split is processed separately; every sample's audio is saved as a .wav file and recorded in the final dataset
99
+ final_dataset_dict = {} # holds the final dataset object for each split
100
+
101
+ for split_name, split_dataset in dataset.items():
102
+ print(f"Processing split: {split_name} with {len(split_dataset)} examples")
103
+ split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
104
+ os.makedirs(split_output_dir, exist_ok=True)
105
+
106
+ # Progress record used to resume from where the last run stopped
107
+ progress_file = os.path.join(split_output_dir, "progress.txt")
108
+ start_index = 0
109
+ if os.path.exists(progress_file):
110
+ try:
111
+ with open(progress_file, "r") as f:
112
+ start_index = int(f.read().strip())
113
+ print(f"Resuming split '{split_name}' from sample index {start_index}")
114
+ except Exception as e:
115
+ print(f"读取进度文件失败:{e}")
116
+
117
+ final_samples = [] # stores the final dataset's sample entries
118
+ for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
119
+ if i < start_index:
120
+ # If the sample was already processed, reuse the existing wav file path (assumed to have been generated earlier) and add it to the final dataset
121
+ sample = split_dataset[i]
122
+ wav_path = os.path.join(split_output_dir, f"{i}.wav")
123
+ # Only add the sample to the final dataset if the file exists
124
+ if os.path.exists(wav_path):
125
+ final_samples.append({
126
+ "question_text": sample["question"],
127
+ "answer": sample["answer"],
128
+ "audio_filepath": wav_path
129
+ })
130
+ continue
131
+
132
+ sample = split_dataset[i]
133
+ # Run the TTS conversion
134
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
135
+
136
+ if result is not None:
137
+ # Make sure the audio tensor shape is (channels, samples)
138
+ audio_tensor = result['audio_tensor']
139
+ if audio_tensor.dim() == 1:
140
+ audio_tensor = audio_tensor.unsqueeze(0)
141
+ sample_rate_val = result['sample_rate']
142
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
143
+ try:
144
+ torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
145
+ except Exception as e:
146
+ print(f"Failed to save wav for sample {i}: {e}")
147
+ continue
148
+
149
+ # Store the converted sample's info in the final dataset
150
+ final_samples.append({
151
+ "question_text": sample["question"],
152
+ "answer": sample["answer"],
153
+ "audio_filepath": output_wav_path
154
+ })
155
+ else:
156
+ print(f"Sample {i} processing failed, no audio generated.")
157
+
158
+ # Update the progress record
159
+ with open(progress_file, "w") as f:
160
+ f.write(str(i + 1))
161
+
162
+ # Save the current split's final samples as a Hugging Face Dataset on disk
163
+ final_dataset_obj = Dataset.from_list(final_samples)
164
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
165
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
166
+ print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
167
+ final_dataset_dict[split_name] = final_dataset_obj
168
+
169
+ print("所有分割处理完毕,最终数据集已保存。")
r1-a/dataset/gsm8k_with_audio/test/299.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faf7e5efc632e11631a878b477523f9a010ad5e3bba9db05f3bc20cafadab95e
3
+ size 2277200
r1-a/dataset/gsm8k_with_audio/test/301.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5010c552eb3c63b659ef8d71af813db7706b55c302f9c9e1c17a831c8bc896a4
3
+ size 2346320
r1-a/dataset/gsm8k_with_audio/test/302.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06e41bddcccc46feb3fd116d32c245ef65cf0434acda1313d020d286a947ddcd
3
+ size 917840
r1-a/dataset/gsm8k_with_audio/test/314.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:026282fcebd0c6ba5bef04291128bbc5dd82eadeb2c020faa0a53135477a7e19
3
+ size 1332560
r1-a/dataset/gsm8k_with_audio/test/316.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83a29e59fc73ac1eb3b49033da596e47141d5911ceae61effc736b4b307a17de
3
+ size 1505360
r1-a/dataset/gsm8k_with_audio/test/350.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5fd7f8dbaf220cd8fe66a151c056e384e07b13cf76778983bbc605043e43d4d
3
+ size 1946960
r1-a/dataset/gsm8k_with_audio/test/358.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a9416095b31c0d238bb1321e6108733973fa651cfcee1fca8f898d18470a89b
3
+ size 1363280
r1-a/dataset/gsm8k_with_audio/test/359.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1ba8571451ad8a8bce5fdf235517e0aae650b7a42d9cc763e92adc7e9c949de
3
+ size 1532240
r1-a/dataset/gsm8k_with_audio/test/369.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4365c6f3f261264849987b3a13a1d552c1004977c0afea259b9e29dad8fee08c
3
+ size 1866320
r1-a/dataset/gsm8k_with_audio/test/372.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a22d041239c1485f42ac79c20fd1ae2a9174fef598e36030c516c882663eae3c
3
+ size 1251920
r1-a/dataset/gsm8k_with_audio/test/376.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65160b2855b4f25f29a7e0f4b5da970cba1e60ae1f1e637e5adfd7495c896f79
3
+ size 1198160
r1-a/dataset/gsm8k_with_audio/test/385.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bbcea803a885d3d5cc8b17f23837a1820f119f9564006023f8781b96a2114f4
3
+ size 2165840
r1-a/dataset/gsm8k_with_audio/test/394.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec4937195a9d64a026cd352a09943a56ee64344fedda73cb792d4d12b466fe3f
3
+ size 1344080
r1-a/dataset/gsm8k_with_audio/test/395.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f84925ad72d43ff5b7bf24aa71472f08bc0ff5b43a09ed92929848451401ff0
3
+ size 1804880
r1-a/dataset/gsm8k_with_audio/test/397.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc7c4027e38fd2090b026f8112151fd564d5a80df9ff230dbac69cb082f60c59
3
+ size 1056080
r1-a/dataset/gsm8k_with_audio/test/400.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf0dace278147001a9b14f5701f5d206039b26438e50f19bcaf6015b46dab78e
3
+ size 867920
r1-a/dataset/gsm8k_with_audio/test/401.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a494472207f4c7f2d067dfa1ce3c19d18702b8fca5385aa2f23b7b7a9a869767
3
+ size 940880
r1-a/dataset/gsm8k_with_audio/test/447.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f08dd8d815e539e2d6b2ea2c06320784e3f6b3f862e3d50e20a2c11ac38785b
3
+ size 1586000
r1-a/dataset/gsm8k_with_audio/test/45.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a65adddc9e7cb0a684486badfd68a41610e2acd64e30b8e8a7da76607586092
3
+ size 2112080
r1-a/dataset/gsm8k_with_audio/test/450.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e42276c81767cb468b27f095f71e5a2765272a753ee9d69cb399998561ee3348
3
+ size 2104400
r1-a/dataset/gsm8k_with_audio/test/454.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c94820c5d8bdd41bdaf461ca950650950c0e44fa2427cb86c9ae59bfc8f781fc
3
+ size 599120
r1-a/dataset/gsm8k_with_audio/test/457.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e3aaab17db00c2275f1ac1ebb8d2c710f4e7393488ca0a4fcdb45d57b6da2d9
3
+ size 1774160
r1-a/dataset/gsm8k_with_audio/test/458.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:813d8661f38ffba540960d4c50dd877714a196acd0d9fa1d2f9a9285ab35b40e
3
+ size 1082960
r1-a/dataset/gsm8k_with_audio/test/459.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48f55b9dbcda81e8e98ac58c7e428de7ed472f79d4dbbf763d8f6232d53b3e00
3
+ size 2945360
r1-a/dataset/gsm8k_with_audio/test/463.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c0710e41b29894e6e2d4519a60c00ccdba41584f0a8fa1e9c4800bf69d5b63f
3
+ size 1298000
r1-a/dataset/gsm8k_with_audio/test/465.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccdf5bacc8e5c5c24f086471ad15811f56e355f2814963aaa2f2c94ce916da05
3
+ size 1747280
r1-a/dataset/gsm8k_with_audio/test/515.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48ea6967d2fa10e62b5e74b695650c9cf4f5a06d4a232937627e75f35b93e526
3
+ size 1278800
r1-a/dataset/gsm8k_with_audio/test/877.wav ADDED
Binary file (38.5 kB). View file
 
r1-a/dataset/gsm8k_with_audio/test/964.wav ADDED
Binary file (88.4 kB). View file
 
r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "question_text": {
6
+ "dtype": "string",
7
+ "_type": "Value"
8
+ },
9
+ "answer": {
10
+ "dtype": "string",
11
+ "_type": "Value"
12
+ },
13
+ "audio_filepath": {
14
+ "dtype": "string",
15
+ "_type": "Value"
16
+ }
17
+ },
18
+ "homepage": "",
19
+ "license": ""
20
+ }
r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "a037f9a8bcdfc025",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
r1-a/dataset/gsm8k_with_audio/test/progress.txt ADDED
@@ -0,0 +1 @@
1
+ 1319
r1-a/dataset/pkusafe.py ADDED
1
+ import json
2
+ from collections import defaultdict
3
+ from datasets import load_dataset, Dataset # make sure Dataset is imported
4
+ from tqdm.auto import tqdm
5
+ import traceback
6
+ import os # needed to create directories
7
+
8
+ # --- Configuration parameters ---
9
+ DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF"
10
+ SPLIT_TO_PROCESS = "train" # or 'test', etc.
11
+ OUTPUT_DATASET_DIR = "pku_saferlhf_filtered_unsafe_diverse_hf" # output directory name
12
+
13
+ # --- Main script logic (same as the previous version) ---
14
+
15
+ def get_true_harm_categories(harm_category_dict):
16
+ """从 harm_category 字典中提取值为 True 的键(类别名称)"""
17
+ if not isinstance(harm_category_dict, dict):
18
+ return []
19
+ return [category for category, is_present in harm_category_dict.items() if is_present]
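+ # e.g. get_true_harm_categories({"Violence": True, "Privacy Violation": False}) -> ["Violence"] (illustrative keys)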
20
+
21
+ def filter_pku_saferlhf_detailed(dataset_name: str, split: str):
22
+ """
23
+ Load and filter the PKU-SafeRLHF dataset: keep samples that have at least one unsafe response, extract the names of the harm categories involved,
24
+ and try to cover every harm category name present in the original data.
25
+ """
26
+ print(f"加载数据集: {dataset_name}, split: {split}...")
27
+ try:
28
+ dataset = load_dataset(dataset_name, split=split)
29
+ print("数据集加载完成。")
30
+ except Exception as e:
31
+ print(f"错误:无法加载数据集 {dataset_name} (split: {split})。请检查名称和可用 split。")
32
+ print(f"错误详情: {e}")
33
+ return None, set()
34
+
35
+ # --- 步骤 1: 识别原始数据集中所有出现过的 harm_category 名称 ---
36
+ print("正在识别所有原始 harm_category 名称...")
37
+ all_harm_category_names_in_split = set()
38
+ scan_errors = 0
39
+ for example in tqdm(dataset, desc="扫描 harm_category"):
40
+ try:
41
+ for harm_dict_key in ["response_0_harm_category", "response_1_harm_category"]:
42
+ harm_dict = example.get(harm_dict_key)
43
+ if isinstance(harm_dict, dict):
44
+ true_categories = get_true_harm_categories(harm_dict)
45
+ all_harm_category_names_in_split.update(true_categories)
46
+ except Exception as e:
47
+ scan_errors += 1
48
+ continue
49
+
50
+ if scan_errors > 0:
51
+ print(f"警告:在扫描 harm_category 阶段,有 {scan_errors} 个样本处理出错。")
52
+ if not all_harm_category_names_in_split:
53
+ print("错误:未能从数据集中提取任何 harm_category 名称。请检查数据集结构或内容。")
54
+ return None, set()
55
+ print(f"原始数据集中找到 {len(all_harm_category_names_in_split)} 个唯一的 harm_category 名称。")
56
+
57
+ # --- 步骤 2: 筛选至少有一个不安全回答的样本,并提取涉及的类别 ---
58
+ print("正在筛选包含不安全回答的样本并提取涉及的类别...")
59
+ filtered_data_list = [] # 改名为 list 以区分最终的 Dataset 对象
60
+ covered_categories_in_filtered = set()
61
+ processed_count = 0
62
+ filtered_count = 0
63
+ filter_errors = 0
64
+
65
+ for example in tqdm(dataset, desc="筛选不安全样本"):
66
+ processed_count += 1
67
+ try:
68
+ prompt = example.get("prompt")
69
+ response_0 = example.get("response_0")
70
+ response_1 = example.get("response_1")
71
+ is_safe_0 = example.get("is_response_0_safe")
72
+ is_safe_1 = example.get("is_response_1_safe")
73
+ response_0_harm_dict = example.get("response_0_harm_category")
74
+ response_1_harm_dict = example.get("response_1_harm_category")
75
+
76
+ if not all([prompt, response_0, response_1,
77
+ is_safe_0 is not None, is_safe_1 is not None]):
78
+ continue
79
+
80
+ if not is_safe_0 or not is_safe_1:
81
+ involved_harm_categories = set()
82
+
83
+ if not is_safe_0:
84
+ categories_0 = get_true_harm_categories(response_0_harm_dict)
85
+ involved_harm_categories.update(categories_0)
86
+
87
+ if not is_safe_1:
88
+ categories_1 = get_true_harm_categories(response_1_harm_dict)
89
+ involved_harm_categories.update(categories_1)
90
+
91
+ # (可选过滤逻辑)
92
+ # if not involved_harm_categories:
93
+ # continue
94
+
95
+ filtered_sample = {
96
+ "prompt": prompt,
97
+ "response_0": response_0,
98
+ "response_1": response_1,
99
+ "is_safe_0": is_safe_0,
100
+ "is_safe_1": is_safe_1,
101
+ # **** 注意:Dataset 对象对于 list of strings 的支持更好 ****
102
+ "involved_harm_categories": sorted(list(involved_harm_categories)),
103
+ "better_response_id": example.get("better_response_id"),
104
+ "safer_response_id": example.get("safer_response_id"),
105
+ # 可以根据需要添加其他字段
106
+ }
107
+ filtered_data_list.append(filtered_sample)
108
+ covered_categories_in_filtered.update(involved_harm_categories)
109
+ filtered_count += 1
110
+
111
+ except Exception as e:
112
+ filter_errors += 1
113
+ continue
114
+
115
+ if filter_errors > 0:
116
+ print(f"警告:在筛选阶段,有 {filter_errors} 个样本处理出错。")
117
+ print(f"筛选完成。共处理 {processed_count} 个样本,筛选出 {filtered_count} 个符合条件的样本。")
118
+
119
+ # --- 步骤 3: 检查 harm_category 覆盖情况 ---
120
+ missing_categories = all_harm_category_names_in_split - covered_categories_in_filtered
121
+ if missing_categories:
122
+ print(f"\n警告:以下 harm_category 名称存在于原始数据集中,但在筛选出的不安全样本中未能找到对应的类别: {missing_categories}")
123
+ else:
124
+ print("\n好消息!所有原始数据集中存在的 harm_category 名称都已在筛选后的数据中得到覆盖。")
125
+ print(f"最终数据集包含 {len(filtered_data_list)} 个样本,覆盖 {len(covered_categories_in_filtered)} 个 harm_category 名称。")
126
+
127
+ return filtered_data_list, covered_categories_in_filtered # 返回 list 和 set
128
+
129
+ # --- 主程序 ---
130
+ if __name__ == "__main__":
131
+ # 执行过滤
132
+ filtered_list, final_categories = filter_pku_saferlhf_detailed(DATASET_NAME, SPLIT_TO_PROCESS)
133
+
134
+ if filtered_list: # 检查返回的列表是否非空
135
+ print(f"\n将 {len(filtered_list)} 条过滤后的数据保存为 Hugging Face Dataset 格式...")
136
+ print(f"目标目录: {OUTPUT_DATASET_DIR}")
137
+
138
+ try:
139
+ # 1. 将 list of dicts 转换为 Dataset 对象
140
+ # Hugging Face 会自动推断列类型。 involved_harm_categories 是 list of strings。
141
+ hf_dataset = Dataset.from_list(filtered_list)
142
+
143
+ # 2. 保存到磁盘
144
+ if not os.path.exists(OUTPUT_DATASET_DIR):
145
+ os.makedirs(OUTPUT_DATASET_DIR)
146
+ print(f"已创建目录: {OUTPUT_DATASET_DIR}")
147
+
148
+ hf_dataset.save_to_disk(OUTPUT_DATASET_DIR)
149
+ print(f"\n数据集成功保存到目录: {OUTPUT_DATASET_DIR}")
150
+ print(f"你可以使用以下代码加载它:")
151
+ print(f"from datasets import load_from_disk")
152
+ print(f"loaded_dataset = load_from_disk('{OUTPUT_DATASET_DIR}')")
153
+ print(f"\n最终数据集覆盖的 harm_category 名称: {final_categories}")
154
+
155
+ # 打印一些样本看看 (从 Dataset 对象加载)
156
+ print("\n部分样本预览 (从保存的 Dataset 加载):")
157
+ loaded_dataset = Dataset.load_from_disk(OUTPUT_DATASET_DIR) # 加载回来验证
158
+ for i in range(min(5, len(loaded_dataset))):
159
+ sample = loaded_dataset[i]
160
+ print(f"--- 样本 {i+1} ---")
161
+ print(f"Prompt: {sample['prompt'][:150]}...")
162
+ print(f"Response 0 (is_safe={sample['is_safe_0']}): {sample['response_0'][:100]}...")
163
+ print(f"Response 1 (is_safe={sample['is_safe_1']}): {sample['response_1'][:100]}...")
164
+ print(f"Involved Harm Categories: {sample['involved_harm_categories']}")
165
+
166
+ except Exception as e:
167
+ print(f"\n错误:保存 Hugging Face Dataset 时出错: {e}")
168
+ traceback.print_exc() # 打印详细错误信息
169
+
170
+ else:
171
+ print("\n未能筛选出任何符合条件的样本,或在加载/处理数据时发生严重错误。未保存任何内容。")
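
To make the keep/drop rule in filter_pku_saferlhf_detailed above concrete: a pair is kept when either is_response_*_safe flag is False, and involved_harm_categories is the sorted union of the categories marked True on whichever responses are unsafe. A minimal sketch with a made-up record (the field values are hypothetical; only the logic mirrors the script):

example = {
    "is_response_0_safe": True,
    "is_response_1_safe": False,
    "response_1_harm_category": {"Violence": True, "Privacy": False, "Fraud": True},
}
# Keep the pair because at least one response is flagged unsafe.
keep = (not example["is_response_0_safe"]) or (not example["is_response_1_safe"])   # True
# Record the harm categories flagged True on the unsafe response.
involved = sorted(k for k, v in example["response_1_harm_category"].items() if v)   # ['Fraud', 'Violence']
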
r1-a/dataset/pkusafe_tts.py ADDED
@@ -0,0 +1,279 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ # Import load_from_disk to load the dataset saved by your first script
6
+ from datasets import load_dataset, Dataset, load_from_disk
7
+ import sys
8
+ from tqdm import tqdm
9
+ import time # Import time for potential delays between retries
10
+
11
+ sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct for your environment
12
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
13
+ from cosyvoice.utils.file_utils import load_wav
14
+
15
+ # ------------------------
16
+ # 配置参数
17
+ # ------------------------
18
+ COMMON_VOICE_LANGUAGE = "en"
19
+ # --- Path to the pre-filtered dataset saved by your FIRST script ---
20
+ FILTERED_DATASET_PATH = "pku_saferlhf_filtered_unsafe_diverse_hf" # <-- IMPORTANT: Make sure this matches the OUTPUT_DATASET_DIR from your first script
21
+ # --- Output path for THIS TTS script ---
22
+ OUTPUT_DATASET_PATH = './pku_saferlhf_filtered_with_audio' # <-- New output path for the dataset with audio
23
+ SAMPLE_RATE = 16000
24
+ MAX_TTS_RETRIES = 3 # Maximum number of TTS attempts per query
25
+ RETRY_DELAY_SECONDS = 2 # Optional delay between retries
26
+
27
+ # ------------------------
28
+ # 辅助函数 (Identical to the previous version with retry logic)
29
+ # ------------------------
30
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
31
+ """
32
+ 从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。
33
+ """
34
+ idx = random.randint(0, len(common_voice_dataset) - 1)
35
+ sample = common_voice_dataset.select([idx])[0]
36
+ audio = sample['audio']
37
+ waveform = torch.tensor(audio['array'], dtype=torch.float32)
38
+ sr = audio['sampling_rate']
39
+ if sr != sample_rate:
40
+ if waveform.dim() > 1:
41
+ waveform = waveform.mean(dim=0)
42
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
43
+ waveform = resampler(waveform)
44
+ if waveform.dim() == 1:
45
+ waveform = waveform.unsqueeze(0)
46
+ if waveform.numel() == 0 or not sample['raw_text']:
47
+ print("Warning: Got an empty prompt, trying again...")
48
+ return get_random_prompt(common_voice_dataset, sample_rate)
49
+ return waveform, sample['raw_text']
50
+
51
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
52
+ """
53
+ 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。
54
+ Includes retry logic on failure.
55
+ """
56
+ last_exception = None
+ prompt_text = ""  # pre-initialize so the except block below can print it even if prompt selection itself fails
57
+ for attempt in range(max_retries):
58
+ try:
59
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
60
+
61
+ all_speech = []
62
+ inference_generator = cosyvoice.inference_zero_shot(
63
+ query_text,
64
+ prompt_text,
65
+ prompt_speech,
66
+ stream=stream,
67
+ text_frontend=False
68
+ )
69
+ for i, chunk in enumerate(inference_generator):
70
+ if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
71
+ all_speech.append(chunk['tts_speech'])
72
+ else:
73
+ print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
74
+
75
+ if not all_speech:
76
+ raise ValueError("TTS inference finished but produced no audio chunks.")
77
+
78
+ combined_speech = torch.cat(all_speech, dim=-1)
79
+ sample_rate_val = cosyvoice.sample_rate
80
+
81
+ return {
82
+ 'audio_tensor': combined_speech,
83
+ 'sample_rate': sample_rate_val
84
+ }
85
+ except Exception as e:
86
+ last_exception = e
87
+ print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
88
+ print(f"Text: '{query_text[:100]}...'")
89
+ print(f"Prompt Text: '{prompt_text[:100]}...'")
90
+ if attempt < max_retries - 1:
91
+ print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
92
+ time.sleep(RETRY_DELAY_SECONDS)
93
+ else:
94
+ print(f"All {max_retries} TTS attempts failed.")
95
+
96
+ print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
97
+ print(f"Last error: {last_exception}")
98
+ return None
99
+
100
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
101
+ """
102
+ 针对从磁盘加载的 PKU-SafeRLHF 过滤后数据集中的单个样本进行 TTS 处理。
103
+ Processes example['prompt']. <--- Changed from 'query'/'question'
104
+ """
105
+ # --- Target the 'prompt' field from the filtered PKU-SafeRLHF dataset ---
106
+ query = example.get('prompt') # <--- Use 'prompt' field
107
+ if not query or not isinstance(query, str) or query.strip() == "":
108
+ print(f"Warning: Skipping example due to missing or empty 'prompt' field: {example.keys()}") # Log keys if prompt is missing
109
+ return None
110
+
111
+ # --- Use the text_to_audio function with retry logic ---
112
+ audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
113
+
114
+ if audio_result is not None:
115
+ audio_tensor = audio_result['audio_tensor']
116
+ if audio_tensor.dim() == 1:
117
+ audio_tensor = audio_tensor.unsqueeze(0)
118
+ elif audio_tensor.dim() > 2:
119
+ print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
120
+ audio_tensor = audio_tensor.view(1, -1)
121
+
122
+ if audio_tensor.numel() == 0:
123
+ print(f"Warning: Generated audio tensor is empty for prompt: '{query[:60]}...'")
124
+ return None
125
+
126
+ return {
127
+ 'audio_tensor': audio_tensor,
128
+ 'sample_rate': audio_result['sample_rate']
129
+ }
130
+ else:
131
+ return None
132
+
133
+ # ------------------------
134
+ # 数据加载与模型初始化
135
+ # ------------------------
136
+ print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
137
+ try:
138
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
139
+ print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
140
+ if len(common_voice) == 0:
141
+ raise ValueError("VoxPopuli dataset loaded but contains no samples.")
142
+ except Exception as e:
143
+ print(f"Error loading VoxPopuli dataset: {e}")
144
+ sys.exit(1)
145
+
146
+
147
+ print("Initializing CosyVoice2 model...")
148
+ try:
149
+ cosyvoice = CosyVoice2(
150
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # Verify this path is correct
151
+ load_jit=True,
152
+ load_trt=False,
153
+ fp16=False
154
+ )
155
+ except Exception as e:
156
+ print(f"Error initializing CosyVoice2 model: {e}")
157
+ sys.exit(1)
158
+
159
+ print(f"Loading pre-filtered PKU-SafeRLHF dataset from disk: {FILTERED_DATASET_PATH}")
160
+ try:
161
+ # --- Load the dataset saved by your first script ---
162
+ filtered_dataset = load_from_disk(FILTERED_DATASET_PATH)
163
+ if not filtered_dataset:
164
+ raise ValueError(f"Dataset loaded from '{FILTERED_DATASET_PATH}' is empty or invalid.")
165
+ print(f"Successfully loaded dataset with {len(filtered_dataset)} examples.")
166
+ # --- Assume the loaded dataset corresponds to a single split (e.g., 'train') ---
167
+ # Wrap it in a dictionary to match the structure expected by the loop below
168
+ dataset_dict = {"train": filtered_dataset} # Use "train" as the key, matching the split processed by the filter script
169
+ # Alternatively, if you know the split name used in the filter script was different, use that name.
170
+ except FileNotFoundError:
171
+ print(f"Error: Pre-filtered dataset not found at '{FILTERED_DATASET_PATH}'.")
172
+ print("Please ensure the first script ran successfully and saved the data to the correct location.")
173
+ sys.exit(1)
174
+ except Exception as e:
175
+ print(f"Error loading pre-filtered dataset from '{FILTERED_DATASET_PATH}': {e}")
176
+ sys.exit(1)
177
+
178
+ # 创建输出目录
179
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
180
+
181
+ # ------------------------
182
+ # 主处理循环
183
+ # ------------------------
184
+ final_dataset_dict = {} # 存放各 split 最终处理后的数据
185
+
186
+ # Iterate through the splits defined in dataset_dict (should just be 'train' in this case)
187
+ for split_name, split_dataset in dataset_dict.items():
188
+ print(f"Processing split: {split_name} with {len(split_dataset)} examples")
189
+ split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
190
+ os.makedirs(split_output_dir, exist_ok=True)
191
+
192
+ # 用于断点续跑的进度记录
193
+ progress_file = os.path.join(split_output_dir, "progress.txt")
194
+ start_index = 0
195
+ if os.path.exists(progress_file):
196
+ try:
197
+ with open(progress_file, "r") as f:
198
+ content = f.read().strip()
199
+ if content:
200
+ start_index = int(content)
201
+ print(f"Resuming split '{split_name}' from sample index {start_index}")
202
+ else:
203
+ print(f"Progress file '{progress_file}' is empty, starting from index 0.")
204
+ start_index = 0
205
+ except ValueError:
206
+ print(f"Could not parse integer from progress file '{progress_file}'. Starting from index 0.")
207
+ start_index = 0
208
+ except Exception as e:
209
+ print(f"Error reading progress file '{progress_file}': {e}. Starting from index 0.")
210
+ start_index = 0
211
+
212
+ final_samples = []
213
+
214
+ # 遍历处理每条样本
215
+ pbar = tqdm(range(start_index, len(split_dataset)), desc=f"Processing {split_name}", initial=start_index, total=len(split_dataset))
216
+ for i in pbar:
217
+ sample = split_dataset[i]
218
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
219
+
220
+ if os.path.exists(output_wav_path):
221
+ sample_dict = {k: sample[k] for k in sample.keys()}
222
+ sample_dict["audio_filepath"] = output_wav_path
223
+ final_samples.append(sample_dict)
224
+ with open(progress_file, "w") as f:
225
+ f.write(str(i + 1))
226
+ continue
227
+
228
+ # --- Perform TTS on the 'prompt' field ---
229
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
230
+
231
+ if result is not None:
232
+ audio_tensor = result['audio_tensor']
233
+ sample_rate_val = result['sample_rate']
234
+
235
+ try:
236
+ audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
237
+ if audio_tensor_save.dim() == 1:
238
+ audio_tensor_save = audio_tensor_save.unsqueeze(0)
239
+ elif audio_tensor_save.dim() > 2:
240
+ audio_tensor_save = audio_tensor_save.view(1, -1)
241
+
242
+ torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
243
+
244
+ # Preserve all original fields from the filtered dataset + add audio path
245
+ sample_dict = {k: sample[k] for k in sample.keys()}
246
+ sample_dict["audio_filepath"] = output_wav_path
247
+ final_samples.append(sample_dict)
248
+
249
+ except Exception as e:
250
+ print(f"Failed to save wav for sample {i} at {output_wav_path}: {e}")
251
+ else:
252
+ print(f"Sample {i} processing failed after retries (Prompt: '{sample.get('prompt', 'N/A')[:60]}...'), no audio generated.")
253
+
254
+ # Update progress file after processing each sample
255
+ with open(progress_file, "w") as f:
256
+ f.write(str(i + 1))
257
+
258
+ # Generate Hugging Face Dataset from the collected successful samples and save
259
+ if final_samples:
260
+ final_dataset_obj = Dataset.from_list(final_samples)
261
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
262
+ try:
263
+ print(f"Saving final dataset for split '{split_name}' to {final_dataset_save_path}...")
264
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
265
+ print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with audio paths.")
266
+ final_dataset_dict[split_name] = final_dataset_obj
267
+ except Exception as e:
268
+ print(f"Error saving final dataset for split '{split_name}' to disk: {e}")
269
+ else:
270
+ print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.")
271
+
272
+
273
+ print("="*30)
274
+ if final_dataset_dict:
275
+ print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.")
276
+ print(f"Processed splits: {list(final_dataset_dict.keys())}")
277
+ else:
278
+ print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.")
279
+ print("="*30)
r1-a/dataset/retry_rewrite.py ADDED
@@ -0,0 +1,442 @@
1
+ import os
2
+ import http.client
3
+ import json
4
+ import time
5
+ import random
6
+ # Import necessary types from datasets
7
+ from datasets import load_dataset, Dataset, DatasetDict, Features, Value, Sequence
8
+ from tqdm.auto import tqdm
9
+ import sys
10
+ import logging
11
+ # socket is needed below: call_llm_api catches socket.gaierror / socket.timeout
12
+ import socket
13
+ import concurrent.futures
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ import shutil # Needed for atomic directory removal
16
+
17
+ # --- Configuration ---
18
+ DATASET_NAME = "virtuoussy/Multi-subject-RLVR" # Or the original source if needed
19
+ DATASET_SPLIT = "train"
20
+ API_HOST = "api2.aigcbest.top"
21
+ API_PATH = "/v1/chat/completions"
22
+ LLM_MODEL = "gpt-4.1-mini"
23
+ API_KEY = os.environ.get('AIGCBEST_API_KEY', "YOUR_API_KEY_HERE")  # never hard-code a real key; the placeholder is rejected by the check below
24
+ if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
25
+ print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.")
26
+ sys.exit(1)
27
+
28
+ OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
29
+ # Path to the existing, potentially incomplete, processed dataset (LOAD ONLY)
30
+ PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
31
+ # Path where intermediate and final results will be saved (SAVE ONLY)
32
+ FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")
33
+
34
+ MAX_WORKERS = 20 # Adjust based on your system and API rate limits
35
+ REQUEST_DELAY_SECONDS = 0.15 # Base delay between requests
36
+ MAX_RETRIES = 3 # Max retries for each API call
37
+ SAVE_INTERVAL = 2000 # <<<--- How often to save progress (in number of processed items)
38
+
39
+ # Setup logging
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
41
+ logging.getLogger("datasets").setLevel(logging.WARNING)
42
+ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
43
+ logging.getLogger("filelock").setLevel(logging.WARNING) # Quiet down filelock warnings during save
44
+
45
+ # --- LLM API Function (call_llm_api) ---
46
+ # (No changes needed here, keep the robust version)
47
+ def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
48
+ system_prompt = (
49
+ "You are an expert linguist specializing in converting structured prompts or "
50
+ "fill-in-the-blank problems into natural, spoken-language questions suitable for "
51
+ "text-to-speech (TTS). Your goal is to make the question sound like how a person "
52
+ "would naturally ask it. "
53
+ "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
54
+ "rephrase it as a direct question asking for the missing information. "
55
+ "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
56
+ "Focus only on rephrasing the *user's question* part provided. "
57
+ "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
58
+ )
59
+ payload = json.dumps({
60
+ "model": model,
61
+ "messages": [
62
+ {"role": "system", "content": system_prompt},
63
+ {"role": "user", "content": original_question}
64
+ ],
65
+ })
66
+ headers = {
67
+ 'Accept': 'application/json',
68
+ 'Authorization': f'Bearer {api_key}',
69
+ 'User-Agent': 'HuggingFace Dataset Processing Script (Retry w/ Save)',
70
+ 'Content-Type': 'application/json'
71
+ }
72
+ time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2))
73
+
74
+ for attempt in range(retries):
75
+ # logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...")
76
+ try:
77
+ conn = http.client.HTTPSConnection(host, timeout=60) # Increased timeout
78
+ conn.request("POST", path, payload, headers)
79
+ res = conn.getresponse()
80
+ status = res.status
81
+ data = res.read()
82
+ conn.close()
83
+
84
+ if status == 200:
85
+ response_json = json.loads(data.decode("utf-8"))
86
+ if response_json.get("choices") and len(response_json["choices"]) > 0:
87
+ message = response_json["choices"][0].get("message")
88
+ if message and message.get("content"):
89
+ rephrased = message["content"].strip()
90
+ # Remove surrounding quotes more robustly
91
+ if len(rephrased) > 1:
92
+ if (rephrased.startswith('"') and rephrased.endswith('"')) or \
93
+ (rephrased.startswith("'") and rephrased.endswith("'")):
94
+ rephrased = rephrased[1:-1]
95
+ # Handle cases like 'Rephrased: "..."'
96
+ if rephrased.lower().startswith(("rephrased:", "here's the rephrased question:")):
97
+ parts = rephrased.split(":", 1)
98
+ if len(parts) > 1:
99
+ potential_rephrased = parts[1].strip()
100
+ if (potential_rephrased.startswith('"') and potential_rephrased.endswith('"')) or \
101
+ (potential_rephrased.startswith("'") and potential_rephrased.endswith("'")):
102
+ rephrased = potential_rephrased[1:-1]
103
+ else:
104
+ rephrased = potential_rephrased
105
+
106
+ if rephrased and rephrased.strip().lower() != original_question.strip().lower():
107
+ # logging.debug(f"Successfully rephrased: {rephrased[:80]}...")
108
+ return rephrased
109
+ elif not rephrased:
110
+ logging.warning(f"LLM returned empty/whitespace response for: {original_question[:50]}...")
111
+ return None
112
+ else:
113
+ logging.warning(f"LLM returned identical response for: {original_question[:50]}...")
114
+ return None # Treat identical as failure
115
+ logging.error(f"Unexpected API response structure: {data.decode('utf-8')}")
116
+ return None
117
+ elif status == 429: # Rate limit
118
+ retry_after_header = res.getheader('Retry-After', '5')
119
+ try: wait_time = int(retry_after_header)
120
+ except ValueError: wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
121
+ logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
122
+ time.sleep(wait_time)
123
+ elif status >= 500: # Server error
124
+ wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
125
+ logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...")
126
+ time.sleep(wait_time)
127
+ else: # Other client errors (4xx) - Don't retry these
128
+ logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}")
129
+ return None
130
+
131
+ except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e:
132
+ logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}")
133
+ if attempt + 1 == retries: return None
134
+ wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
135
+ logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
136
+ time.sleep(wait_time)
137
+ except json.JSONDecodeError as e:
138
+ logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200] if data else 'N/A'}")
139
+ if attempt + 1 == retries: return None
140
+ wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5)
141
+ time.sleep(wait_time)
142
+ except Exception as e:
143
+ logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True)
144
+ if attempt + 1 == retries: return None
145
+ wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3)
146
+ logging.warning(f"Waiting {wait_time:.2f} seconds before retry...")
147
+ time.sleep(wait_time)
148
+
149
+ logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...")
150
+ return None
151
+
152
+ # --- Dataset Processing Function (rephrase_query_entry) ---
153
+ # (No changes needed here)
154
+ def rephrase_query_entry(example):
155
+ processed_example = example.copy()
156
+ original_query_list = example.get("query")
157
+ processed_example['query_rephrased_status'] = 'processing_retry'
158
+
159
+ if original_query_list is None:
160
+ processed_example['query_rephrased_status'] = 'skipped_missing_query_column'
161
+ processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
162
+ return processed_example
163
+ if not isinstance(original_query_list, list):
164
+ processed_example['query_rephrased_status'] = 'skipped_query_not_list'
165
+ processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
166
+ return processed_example
167
+ if not original_query_list:
168
+ processed_example['query_rephrased_status'] = 'skipped_query_list_empty'
169
+ processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
170
+ return processed_example
171
+
172
+ user_question = None
173
+ for message in original_query_list:
174
+ if isinstance(message, dict) and message.get("role") == "user":
175
+ content = message.get("content")
176
+ if isinstance(content, str) and content.strip():
177
+ user_question = content
178
+ break
179
+ else:
180
+ processed_example['query_rephrased_status'] = 'skipped_invalid_user_content'
181
+ processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
182
+ return processed_example
183
+ if not user_question:
184
+ processed_example['query_rephrased_status'] = 'skipped_no_user_content_found'
185
+ processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value
186
+ return processed_example
187
+
188
+ # logging.info(f"Retrying: {user_question[:60]}...")
189
+ rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL)
190
+
191
+ if rephrased_query_content:
192
+ processed_example["query_rephrased"] = rephrased_query_content
193
+ processed_example['query_rephrased_status'] = 'success'
194
+ else:
195
+ # Keep the OLD 'query_rephrased' value if LLM call fails this time
196
+ processed_example['query_rephrased'] = example.get('query_rephrased')
197
+ processed_example['query_rephrased_status'] = 'failed_llm_call'
198
+
199
+ return processed_example
200
+
201
+ # --- Function to Save Dataset Atomically ---
202
+ # Saves to a temporary path then renames for safety.
203
+ def save_dataset_atomically(data_list, output_path, features):
204
+ """Saves the list of data dictionaries atomically using the correct schema."""
205
+ if not data_list:
206
+ logging.info("No data provided for saving.")
207
+ return False
208
+
209
+ temp_output_path = output_path + "_saving" # Temporary directory
210
+ final_output_path = output_path
211
+
212
+ logging.info(f"Attempting to save {len(data_list)} examples to temp path {temp_output_path}...")
213
+ try:
214
+ # Create dataset from the list of dictionaries using the defined features
215
+ processed_dataset = Dataset.from_list(list(data_list), features=features) # Convert just in case
216
+
217
+ # Ensure parent directory exists
218
+ os.makedirs(os.path.dirname(final_output_path), exist_ok=True)
219
+
220
+ # Remove any previous temporary directory if it exists
221
+ if os.path.exists(temp_output_path):
222
+ logging.warning(f"Removing existing temporary save directory: {temp_output_path}")
223
+ shutil.rmtree(temp_output_path) # Use shutil for directories
224
+
225
+ # Save the dataset to the temporary path
226
+ processed_dataset.save_to_disk(temp_output_path)
227
+ logging.info(f"Successfully saved dataset to temporary path: {temp_output_path}")
228
+
229
+ # --- Atomic Rename ---
230
+ # Remove the final destination path if it exists
231
+ if os.path.exists(final_output_path):
232
+ logging.debug(f"Removing existing final destination directory before rename: {final_output_path}")
233
+ shutil.rmtree(final_output_path)
234
+
235
+ # Rename the temporary path to the final path
236
+ os.rename(temp_output_path, final_output_path)
237
+ logging.info(f"Successfully moved temporary save to final path: {final_output_path}")
238
+ return True
239
+
240
+ except Exception as e:
241
+ logging.error(f"Failed during atomic save process to {final_output_path}: {e}", exc_info=True)
242
+ # Attempt to clean up temporary directory if it still exists after failure
243
+ if os.path.exists(temp_output_path):
244
+ try:
245
+ shutil.rmtree(temp_output_path)
246
+ logging.info(f"Cleaned up temporary directory {temp_output_path} after error.")
247
+ except Exception as cleanup_e:
248
+ logging.error(f"Could not clean up temporary directory {temp_output_path} after error: {cleanup_e}")
249
+ # Fallback save attempt to JSON Lines (unchanged)
250
+ fallback_json_path = final_output_path + ".jsonl.failed_save" # Indicate it's a fallback
251
+ logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}")
252
+ try:
253
+ with open(fallback_json_path, 'w', encoding='utf-8') as f:
254
+ for item in data_list:
255
+ f.write(json.dumps(item, ensure_ascii=False) + '\n')
256
+ logging.info(f"Successfully saved fallback JSON Lines file.")
257
+ except Exception as json_e:
258
+ logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True)
259
+ return False
260
+
261
+ # --- Function to Check if Retry is Needed ---
262
+ # (No changes needed here)
263
+ def needs_retry(example):
264
+ rephrased = example.get('query_rephrased')
265
+ status = example.get('query_rephrased_status')
266
+ # Retry if rephrased is missing OR status is anything other than 'success'
267
+ # This ensures failed/skipped items are retried.
268
+ retry_flag = (rephrased is None) or (status != 'success')
269
+ return retry_flag
270
+
271
+ # --- Main Execution ---
272
+ if __name__ == "__main__":
273
+ start_time = time.time()
274
+ logging.info("======================================================")
275
+ logging.info(f" Starting Dataset Processing - RETRY w/ PERIODIC SAVE")
276
+ logging.info(f" Saving progress every {SAVE_INTERVAL} processed items.")
277
+ logging.info("======================================================")
278
+ logging.info(f"Loading existing data from: {PROCESSED_DATA_PATH}")
279
+ logging.info(f"Intermediate and final output will be saved to: {FINAL_OUTPUT_PATH}")
280
+
281
+ # --- Load Existing Processed Dataset ---
282
+ if not os.path.exists(PROCESSED_DATA_PATH):
283
+ logging.error(f"Existing data directory not found at '{PROCESSED_DATA_PATH}'. Cannot run retry mode.")
284
+ sys.exit(1)
285
+
286
+ logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...")
287
+ try:
288
+ existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH)
289
+ # Get features *before* converting to list
290
+ dataset_features = existing_dataset.features
291
+ logging.info(f"Dataset features detected: {dataset_features}")
292
+ # Convert to a list of dictionaries for in-memory modification
293
+ results_list = existing_dataset.to_list()
294
+ total_examples = len(results_list)
295
+ logging.info(f"Loaded {total_examples} examples.")
296
+ except Exception as e:
297
+ logging.error(f"Failed to load dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True)
298
+ # Check if the final path exists from a previous run - maybe load that?
299
+ # For now, exiting is safer to avoid inconsistent states.
300
+ # if os.path.exists(FINAL_OUTPUT_PATH):
301
+ # logging.warning(f"Consider manually checking/using the existing data at {FINAL_OUTPUT_PATH}")
302
+ sys.exit(1)
303
+
304
+ # --- Identify Indices to Retry ---
305
+ logging.info("Identifying examples needing retry...")
306
+ indices_to_retry = [
307
+ i for i, example in enumerate(tqdm(results_list, desc="Checking examples")) if needs_retry(example)
308
+ ]
309
+ num_to_retry = len(indices_to_retry)
310
+
311
+ if num_to_retry == 0:
312
+ logging.info("No examples found needing retry based on the criteria ('query_rephrased' is None or status != 'success').")
313
+ logging.info(f"Saving the existing dataset to the final location '{FINAL_OUTPUT_PATH}' as is...")
314
+ if not save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features): # Use atomic save
315
+ logging.error("Failed to save the dataset to the final location even though no retries were needed.")
316
+ sys.exit(0)
317
+
318
+ logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.")
319
+
320
+ # --- Prepare for Concurrent Retries ---
321
+ processed_count_total = 0 # Total processed in this run
322
+ processed_since_last_save = 0 # Counter for periodic saving
323
+ last_save_time = time.time() # Track time for saving message
324
+
325
+ logging.info("Starting concurrent retries with periodic saving...")
326
+
327
+ # --- ThreadPoolExecutor for Concurrency ---
328
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
329
+ # Submit tasks only for the identified indices
330
+ futures = {
331
+ executor.submit(rephrase_query_entry, results_list[i]): i
332
+ for i in indices_to_retry
333
+ }
334
+
335
+ try:
336
+ # Initialize progress bar for retries
337
+ pbar = tqdm(total=num_to_retry, desc="Retrying examples", unit="example")
338
+ # Process futures as they complete
339
+ for future in concurrent.futures.as_completed(futures):
340
+ original_index = futures[future] # Get the original list index
341
+ try:
342
+ # Get the result (the updated dictionary)
343
+ updated_example_dict = future.result()
344
+ # --- IMMEDIATE UPDATE of the main list ---
345
+ results_list[original_index] = updated_example_dict
346
+ pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True)
347
+
348
+ except Exception as exc:
349
+ # Catch potential exceptions *from* the rephrase_query_entry function
350
+ logging.error(f'Retry task for index {original_index} encountered an exception: {exc}', exc_info=True)
351
+ # Create an error placeholder and update the main list
352
+ error_placeholder = results_list[original_index].copy() # Start with original data
353
+ error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}'
354
+ # Keep the old query_rephrased value
355
+ results_list[original_index] = error_placeholder
356
+ pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True)
357
+
358
+ finally:
359
+ # Increment counters and update progress bar
360
+ processed_count_total += 1
361
+ processed_since_last_save += 1
362
+ pbar.update(1)
363
+
364
+ # --- Periodic Save Check ---
365
+ if processed_since_last_save >= SAVE_INTERVAL:
366
+ current_time = time.time()
367
+ time_since_last = current_time - last_save_time
368
+ logging.info(f"\n--- Processed {processed_since_last_save} items (Total: {processed_count_total}/{num_to_retry}). Time since last save: {time_since_last:.1f}s. Saving progress... ---")
369
+ if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
370
+ logging.info(f"--- Progress successfully saved to {FINAL_OUTPUT_PATH} ---")
371
+ processed_since_last_save = 0 # Reset counter
372
+ last_save_time = current_time
373
+ else:
374
+ logging.error(f"--- FAILED TO SAVE PROGRESS! Check errors above. Will retry saving later. ---")
375
+ # Don't reset the counter, maybe the next save will work
376
+
377
+ except KeyboardInterrupt:
378
+ logging.warning("\nCtrl+C detected! Attempting final save...")
379
+ # Let the finally block handle the save
380
+
381
+ except Exception as e:
382
+ logging.error(f"An unexpected error occurred during the main retry loop: {e}", exc_info=True)
383
+ logging.error("Attempting final save...")
384
+ # Let the finally block handle the save
385
+
386
+ finally:
387
+ # --- This block executes after the loop finishes, OR if an exception/interrupt occurs ---
388
+ if 'pbar' in locals() and pbar is not None:
389
+ pbar.close()
390
+
391
+ logging.info("--- Processing loop finished or interrupted. ---")
392
+
393
+ # --- Final Save Attempt ---
394
+ # No need to update results_list again, it was updated incrementally.
395
+ logging.info(f"Attempting final save of the dataset ({len(results_list)} items) to: {FINAL_OUTPUT_PATH}")
396
+ if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features):
397
+ logging.info("--- Final dataset state saved successfully. ---")
398
+ else:
399
+ logging.error(">>> FINAL SAVE FAILED! <<< Check logs. Fallback JSON file might exist.")
400
+
401
+ # --- Final Verification (Optional but Recommended) ---
402
+ logging.info("------------------------------------------------------")
403
+ logging.info("Verification: Attempting to load final saved dataset...")
404
+ try:
405
+ final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
406
+ logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
407
+
408
+ # Simple status count
409
+ status_counts = {}
410
+ none_rephrased_count = 0
411
+ for ex in final_reloaded_dataset:
412
+ status = ex.get('query_rephrased_status', 'unknown_status')
413
+ status_counts[status] = status_counts.get(status, 0) + 1
414
+ if ex.get('query_rephrased') is None or not str(ex.get('query_rephrased')).strip():
415
+ none_rephrased_count += 1
416
+
417
+ logging.info("Final status counts:")
418
+ for status, count in sorted(status_counts.items()):
419
+ logging.info(f" - {status}: {count}")
420
+
421
+ final_success = status_counts.get('success', 0)
422
+ final_failed = sum(count for st, count in status_counts.items() if st and (st.startswith('failed_') or st == 'processing_retry')) # Items potentially stuck
423
+ final_skipped = sum(count for st, count in status_counts.items() if st and st.startswith('skipped_'))
424
+ other_count = len(final_reloaded_dataset) - final_success - final_failed - final_skipped
425
+
426
+ logging.info(f"Summary: Success={final_success}, Failed/Incomplete={final_failed}, Skipped={final_skipped}, Other={other_count}")
427
+ if none_rephrased_count > 0:
428
+ logging.warning(f"WARNING: {none_rephrased_count} items have None/empty 'query_rephrased' in the final dataset.")
429
+ if final_failed > 0:
430
+ logging.warning(f"WARNING: {final_failed} items did not reach 'success' or 'skipped' status.")
431
+
432
+
433
+ except FileNotFoundError:
434
+ logging.error(f"Verification failed: Final dataset directory not found at {FINAL_OUTPUT_PATH}. Final save likely failed.")
435
+ except Exception as e:
436
+ logging.error(f"Verification failed: Could not reload/verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)
437
+
438
+ # --- Script End ---
439
+ end_time = time.time()
440
+ logging.info("------------------------------------------------------")
441
+ logging.info(f"Script finished in {end_time - start_time:.2f} seconds.")
442
+ logging.info("======================================================")
r1-a/dataset/retts.py ADDED
@@ -0,0 +1,559 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import argparse
4
+ import torch
5
+ import re
6
+ import jiwer
7
+ from datasets import load_from_disk, concatenate_datasets, Dataset, Features, Value, Audio # Keep Audio for potential output type hint if needed
8
+ from transformers import pipeline
9
+ import logging
10
+ import time
11
+ import soundfile as sf # For checking validity
12
+ import librosa # For loading audio in batch function
13
+ import numpy as np
14
+ import collections
15
+ import pyarrow as pa
16
+
17
+ # --- 配置日志 ---
18
+ log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - [Shard %(shard_index)s] - %(message)s')
19
+ logger = logging.getLogger()
20
+ logger.setLevel(logging.INFO)
21
+ ch = logging.StreamHandler()
22
+ ch.setLevel(logging.INFO)
23
+ ch.setFormatter(log_formatter)
24
+ logger.addHandler(ch)
25
+ fh = None # File handler setup in main
26
+
27
+ # --- 常量与参数定义 ---
28
+ MODEL_ID = "openai/whisper-large-v3"
29
+ DATASET_PATH = "/home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative" # ADJUST IF NEEDED
30
+ OUTPUT_DIR = "/home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative_processed_shards" # ADJUST IF NEEDED
31
+ LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
32
+ NUM_SHARDS = 50
33
+ MIN_AUDIO_DURATION_MS = 100
34
+ TARGET_SR = 16000 # Whisper expected sample rate
35
+
36
+ # --- 文本规范化函数 ---
37
+ def normalize_text(text):
38
+ if text is None:
39
+ return ""
40
+ text = str(text).lower()
41
+ text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、]', '', text)
42
+ text = re.sub(r'\s+', ' ', text).strip()
43
+ return text
44
+
45
+ # --- 音频文件预检查函数 (检查路径) ---
46
+ # (This function remains largely the same as it already worked with paths)
47
+ def check_audio_file_validity(example, shard_idx_for_log=None):
48
+ is_valid = False
49
+ error_msg = "Unknown error"
50
+ duration_ms = 0
51
+ audio_path = example.get("question_audio") # Directly get the path string
52
+ log_prefix = f"[Shard {shard_idx_for_log}] " if shard_idx_for_log is not None else ""
53
+
54
+ if audio_path and isinstance(audio_path, str):
55
+ if os.path.exists(audio_path):
56
+ try:
57
+ info = sf.info(audio_path)
58
+ duration_ms = int(info.duration * 1000)
59
+ if info.samplerate > 0 and info.frames > 0:
60
+ if duration_ms >= MIN_AUDIO_DURATION_MS:
61
+ is_valid = True
62
+ error_msg = None
63
+ else:
64
+ error_msg = f"Audio duration {duration_ms}ms < minimum {MIN_AUDIO_DURATION_MS}ms"
65
+ else:
66
+ error_msg = "Invalid audio properties (samplerate/frames <= 0)"
67
+ except Exception as e:
68
+ logger.warning(f"{log_prefix}Cannot read info/validate file {audio_path}: {type(e).__name__}")
69
+ error_msg = f"Cannot read/validate audio: {type(e).__name__}"
70
+ else:
71
+ error_msg = "Audio file not found"
72
+ elif audio_path is None:
73
+ error_msg = "Audio path is missing or null"
74
+ else:
75
+ error_msg = f"Audio path is not a string (type: {type(audio_path).__name__})"
76
+
77
+ return {
78
+ "audio_is_valid": is_valid,
79
+ "audio_check_error": error_msg,
80
+ "audio_duration_ms": duration_ms
81
+ # Don't add the original path back here, it's already in the dataset
82
+ }
83
+
84
+
85
+ # --- 核心处理函数 (批处理 - 加载音频路径) ---
86
+ def check_audio_quality_batch(batch, asr_pipeline, wer_threshold, target_sr, shard_idx_for_log=None):
87
+ log_prefix = f"[Shard {shard_idx_for_log}] " if shard_idx_for_log is not None else ""
88
+ results = {"asr_transcription": [], "wer": [], "is_bad_tts": [], "error_message": []}
89
+ original_texts = batch.get("question_text", [])
90
+ audio_paths = batch.get("question_audio", []) # Get list of paths
91
+
92
+ num_samples_in_batch = len(audio_paths)
93
+ if not audio_paths or not original_texts or len(audio_paths) != len(original_texts):
94
+ logger.warning(f"{log_prefix}Batch inconsistency or empty data. Paths: {len(audio_paths)}, Text: {len(original_texts)}")
95
+ num_samples = max(len(audio_paths), len(original_texts))
96
+ results["asr_transcription"] = [""] * num_samples
97
+ results["wer"] = [1.0] * num_samples
98
+ results["is_bad_tts"] = [True] * num_samples
99
+ results["error_message"] = ["Inconsistent batch data or missing paths/text"] * num_samples
100
+ return results
101
+
102
+ batch_load_start_time = time.time()
103
+ loaded_audios = []
104
+ load_errors = [None] * num_samples_in_batch # Track loading errors per sample
105
+
106
+ # --- 加载批次中的所有音频 ---
107
+ for i, path in enumerate(audio_paths):
108
+ try:
109
+ if not path or not isinstance(path, str):
110
+ raise ValueError("Invalid audio path")
111
+ # Load using librosa, force mono, resample to target_sr
112
+ audio_array, sample_rate = librosa.load(path, sr=target_sr, mono=True)
113
+ loaded_audios.append(audio_array)
114
+ except Exception as e:
115
+ logger.warning(f"{log_prefix}Failed to load audio file '{path}': {type(e).__name__}. Skipping for ASR.")
116
+ loaded_audios.append(None) # Use None as placeholder for failed loads
117
+ load_errors[i] = f"Audio load failed: {type(e).__name__}"
118
+
119
+ batch_load_end_time = time.time()
120
+ logger.debug(f"{log_prefix}Loaded {len([a for a in loaded_audios if a is not None])}/{num_samples_in_batch} audios in {batch_load_end_time - batch_load_start_time:.2f} sec.")
121
+
122
+ # Filter out None placeholders before sending to ASR pipeline?
123
+ # Option 1: Send only valid audios (might complicate matching results back)
124
+ # Option 2: Send list including None/empty arrays, let pipeline handle (or pre-handle)
125
+ # Let's try Option 2 with pre-handling: Replace None with empty array for pipeline input
126
+ pipeline_inputs = []
127
+ valid_indices = [] # Track indices of samples sent to pipeline
128
+ for i, audio_data in enumerate(loaded_audios):
129
+ if audio_data is not None and len(audio_data) > 0: # Check if loading succeeded and audio not empty
130
+ pipeline_inputs.append(audio_data)
131
+ valid_indices.append(i)
132
+ # else: keep load_errors[i] message
133
+
134
+ asr_results_list = [None] * num_samples_in_batch # Initialize results list matching original batch size
135
+
136
+ # --- ASR 推理 (仅对成功加载的音频) ---
137
+ if pipeline_inputs: # Only run pipeline if there are valid audios
138
+ batch_asr_start_time = time.time()
139
+ try:
140
+ # Pass the list of NumPy arrays directly to the pipeline
141
+ asr_outputs = asr_pipeline(pipeline_inputs, generate_kwargs={"language": "zh", "task": "transcribe"})
142
+
143
+ if not isinstance(asr_outputs, list):
144
+ asr_outputs = [asr_outputs] # Ensure it's a list
145
+
146
+ # Map results back to original batch positions using valid_indices
147
+ if len(asr_outputs) == len(valid_indices):
148
+ for idx, result in zip(valid_indices, asr_outputs):
149
+ asr_results_list[idx] = result # Place result at the correct original index
150
+ else:
151
+ logger.error(f"{log_prefix}ASR output count ({len(asr_outputs)}) mismatch with valid input count ({len(valid_indices)}). Marking all in batch as error.")
152
+ # Mark all samples in the batch with an error if counts mismatch
153
+ for i in range(num_samples_in_batch):
154
+ if load_errors[i] is None: # If loading didn't fail, mark as ASR mismatch
155
+ load_errors[i] = "ASR count mismatch error"
156
+
157
+ # --- Error Handling for ASR Pipeline ---
158
+ # Catch errors specifically from the pipeline call
159
+ except ValueError as ve: # e.g., internal batching errors if any remain
160
+ logger.error(f"{log_prefix}ValueError during ASR pipeline processing: {ve}", exc_info=True)
161
+ for idx in valid_indices: # Mark only those sent to pipeline as failed
162
+ asr_results_list[idx] = "ERROR: ASR ValueError" # Placeholder or error indicator
163
+ if load_errors[idx] is None: load_errors[idx] = f"ASR ValueError: {str(ve)[:100]}"
164
+ except torch.cuda.OutOfMemoryError:
165
+ logger.error(f"{log_prefix}CUDA OutOfMemoryError during ASR batch processing.")
166
+ torch.cuda.empty_cache()
167
+ for idx in valid_indices:
168
+ asr_results_list[idx] = "ERROR: ASR OOM"
169
+ if load_errors[idx] is None: load_errors[idx] = "ASR CUDA OOM"
170
+ except Exception as e:
171
+ logger.error(f"{log_prefix}Exception during ASR pipeline processing: {e}", exc_info=True)
172
+ for idx in valid_indices:
173
+ asr_results_list[idx] = "ERROR: ASR Exception"
174
+ if load_errors[idx] is None: load_errors[idx] = f"ASR Exception: {str(e)[:100]}"
175
+
176
+ batch_asr_end_time = time.time()
177
+ logger.debug(f"{log_prefix}ASR processed {len(valid_indices)} audios in {batch_asr_end_time - batch_asr_start_time:.2f} sec.")
178
+
179
+ # --- 计算 WER (遍历原始批次大小) ---
180
+ for i in range(num_samples_in_batch):
181
+ transcription = ""
182
+ wer = 1.0 # Default to max error
183
+ is_bad = True
184
+ error_msg = load_errors[i] # Start with potential loading error
185
+
186
+ asr_result = asr_results_list[i]
187
+
188
+ if error_msg is None: # If no loading error, proceed with ASR result
189
+ if isinstance(asr_result, dict) and "text" in asr_result:
190
+ transcription = asr_result["text"]
191
+ original_text = original_texts[i]
192
+
193
+ norm_original = normalize_text(original_text)
194
+ norm_transcription = normalize_text(transcription)
195
+
196
+ if not norm_original:
197
+ wer = 1.0 if norm_transcription else 0.0
198
+ is_bad = True if norm_transcription else False
199
+ error_msg = "Original text normalized to empty" if is_bad else "Original text normalized to empty, transcription also empty"
200
+ else:
201
+ try:
202
+ wer = jiwer.wer(norm_original, norm_transcription)
203
+ wer = min(wer, 1.0) # Clamp WER
204
+ is_bad = wer > wer_threshold
205
+ except ValueError as e:
206
+ wer = 1.0
207
+ is_bad = True
208
+ logger.warning(f"{log_prefix}Jiwer WER calculation error for idx {i}. Setting WER to 1.0. Error: {e}")
209
+ error_msg = f"WER calculation error: {e}"
210
+ except Exception as e:
211
+ wer = 1.0
212
+ is_bad = True
213
+ logger.error(f"{log_prefix}Unexpected error during WER calculation idx {i}: {e}", exc_info=True)
214
+ error_msg = f"Unexpected WER error: {e}"
215
+ elif isinstance(asr_result, str) and "ERROR:" in asr_result:
216
+ # Handle error strings passed from ASR exception handling
217
+ error_msg = asr_result
218
+ wer = 1.0
219
+ is_bad = True
220
+ else:
221
+ # ASR didn't run (load failed) or returned unexpected format
222
+ # error_msg should already be set from load_errors
223
+ # If error_msg is somehow still None, set a generic one
224
+ if error_msg is None:
225
+ error_msg = "ASR did not produce valid output"
226
+ wer = 1.0
227
+ is_bad = True
228
+
229
+ results["asr_transcription"].append(transcription)
230
+ results["wer"].append(wer)
231
+ results["is_bad_tts"].append(is_bad)
232
+ results["error_message"].append(error_msg)
233
+
234
+ return results
235
+
236
+
237
+ # --- 统计信息记录函数 ---
238
+ # (This function remains the same, as it operates on the processed shard data)
239
+ def log_shard_statistics(processed_shard, shard_index, wer_threshold, processing_time):
240
+ log_prefix = f"[Shard {shard_index}] "
241
+ logger.info(f"{log_prefix}--- Shard {shard_index} Statistics ---")
242
+ # ... (rest of the function is identical to the previous version) ...
243
+ total_samples = len(processed_shard)
244
+ logger.info(f"{log_prefix}Total samples processed in this shard: {total_samples}")
245
+ if total_samples == 0:
246
+ logger.info(f"{log_prefix}Shard was empty, no statistics to report.")
247
+ logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
248
+ return
249
+
250
+ logger.info(f"{log_prefix}Processing time for this shard: {processing_time:.2f} seconds")
251
+ if processing_time > 0:
252
+ logger.info(f"{log_prefix}Overall processing speed: {total_samples / processing_time:.2f} samples/sec")
253
+ logger.info(f"{log_prefix}WER threshold used: {wer_threshold}")
254
+
255
+ required_cols = ['is_bad_tts', 'wer', 'error_message', 'question_text', 'asr_transcription']
256
+ if not all(col in processed_shard.column_names for col in required_cols):
257
+ logger.error(f"{log_prefix}Processed shard is missing required columns for statistics ({required_cols}). Skipping detailed stats.")
258
+ logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
259
+ return
260
+
261
+ try:
262
+ bad_tts_count = sum(processed_shard['is_bad_tts'])
263
+ bad_tts_percentage = (bad_tts_count / total_samples) * 100 if total_samples > 0 else 0
264
+ logger.info(f"{log_prefix}Bad TTS samples (WER > {wer_threshold} or Error): {bad_tts_count} ({bad_tts_percentage:.2f}%)")
265
+ logger.info(f"{log_prefix}Good TTS samples (WER <= {wer_threshold}): {total_samples - bad_tts_count} ({100 - bad_tts_percentage:.2f}%)")
266
+
267
+ wer_scores = [w for w in processed_shard['wer'] if w is not None and not np.isnan(w)]
268
+ if wer_scores:
269
+ logger.info(f"{log_prefix}WER Score Distribution (for samples where WER could be calculated):")
270
+ logger.info(f"{log_prefix} Count: {len(wer_scores)}")
271
+ logger.info(f"{log_prefix} Min: {np.min(wer_scores):.4f}")
272
+ logger.info(f"{log_prefix} Max: {np.max(wer_scores):.4f}") # Should be <= 1.0 now
273
+ logger.info(f"{log_prefix} Mean: {np.mean(wer_scores):.4f}")
274
+ logger.info(f"{log_prefix} Median: {np.median(wer_scores):.4f}")
275
+ q25, q75 = np.percentile(wer_scores, [25, 75])
276
+ logger.info(f"{log_prefix} 25th Percentile: {q25:.4f}")
277
+ logger.info(f"{log_prefix} 75th Percentile: {q75:.4f}")
278
+ else:
279
+ logger.info(f"{log_prefix}WER Score Distribution: No valid WER scores found.")
280
+
281
+ error_messages = [msg for msg in processed_shard['error_message'] if msg]
282
+ if error_messages:
283
+ error_counts = collections.Counter(error_messages)
284
+ logger.info(f"{log_prefix}Error Message Summary (Top 10):")
285
+ for msg, count in error_counts.most_common(10):
286
+ logger.info(f"{log_prefix} - \"{msg}\": {count} occurrences")
287
+ if len(error_counts) > 10:
288
+ logger.info(f"{log_prefix} ... ({len(error_counts) - 10} more error types)")
289
+ else:
290
+ logger.info(f"{log_prefix}Error Message Summary: No processing errors recorded.")
291
+
292
+ logger.info(f"\n{log_prefix}--- Example Good TTS Samples (WER <= {wer_threshold}) ---")
293
+ # Use select for potentially large datasets, disable caching for filter
294
+ good_samples_indices = [i for i, bad in enumerate(processed_shard['is_bad_tts']) if not bad]
295
+ num_good_to_show = min(5, len(good_samples_indices))
296
+ if num_good_to_show > 0:
297
+ # Select the samples using indices; this is faster than filter for small selects
298
+ good_samples_view = processed_shard.select(good_samples_indices[:num_good_to_show])
299
+ for i in range(num_good_to_show):
300
+ sample = good_samples_view[i]
301
+ logger.info(f"{log_prefix} Example {i+1}:")
302
+ logger.info(f"{log_prefix} Original Text: {sample['question_text']}")
303
+ logger.info(f"{log_prefix} ASR Transcript: {sample['asr_transcription']}")
304
+ logger.info(f"{log_prefix} WER: {sample['wer']:.4f}")
305
+ logger.info(f"{log_prefix} Audio Path: {sample['question_audio']}") # Show path
306
+ else:
307
+ logger.info(f"{log_prefix} No good samples found in this shard.")
308
+
309
+
310
+ logger.info(f"\n{log_prefix}--- Example Bad TTS Samples (WER > {wer_threshold} or Error) ---")
311
+ bad_samples_indices = [i for i, bad in enumerate(processed_shard['is_bad_tts']) if bad]
312
+ num_bad_to_show = min(5, len(bad_samples_indices))
313
+ if num_bad_to_show > 0:
314
+ bad_samples_view = processed_shard.select(bad_samples_indices[:num_bad_to_show])
315
+ for i in range(num_bad_to_show):
316
+ sample = bad_samples_view[i]
317
+ logger.info(f"{log_prefix} Example {i+1}:")
318
+ logger.info(f"{log_prefix} Original Text: {sample['question_text']}")
319
+ logger.info(f"{log_prefix} ASR Transcript: {sample['asr_transcription']}")
320
+ logger.info(f"{log_prefix} WER: {sample['wer']:.4f}")
321
+ logger.info(f"{log_prefix} Error Msg: {sample['error_message']}")
322
+ logger.info(f"{log_prefix} Audio Path: {sample['question_audio']}") # Show path
323
+ else:
324
+ logger.info(f"{log_prefix} No bad samples found in this shard.")
325
+
326
+ except Exception as e:
327
+ logger.error(f"{log_prefix}Error generating statistics: {e}", exc_info=True)
328
+
329
+ logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
330
+
331
+
332
+ # --- Main function ---
333
+ def main():
334
+ global fh
335
+ parser = argparse.ArgumentParser(description="Process a shard of the dataset using Whisper ASR, loading audio from paths.")
336
+ # ... (Argument parsing remains the same) ...
337
+ parser.add_argument("--shard_index", type=int, required=True, help=f"Index of the shard to process (0 to {NUM_SHARDS-1}).")
338
+ parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID to use for this process.")
339
+ parser.add_argument("--wer_threshold", type=float, default=0.4, help="WER threshold to mark TTS as bad.")
340
+ parser.add_argument("--pipeline_batch_size", type=int, default=8, help="Internal batch size for the ASR pipeline.")
341
+ parser.add_argument("--map_batch_size", type=int, default=16, help="Batch size for the datasets.map function (how many rows passed to batch func).")
342
+ parser.add_argument("--num_check_workers", type=int, default=4, help="Number of workers for audio pre-check map.")
343
+ args = parser.parse_args()
344
+ shard_index = args.shard_index
345
+ # ... (Rest of argument setup, logging setup, GPU setup - same as before) ...
346
+ gpu_id = args.gpu_id
347
+ wer_threshold = args.wer_threshold
348
+ pipeline_batch_size = args.pipeline_batch_size
349
+ map_batch_size = args.map_batch_size
350
+ num_check_workers = args.num_check_workers
351
+
352
+ os.makedirs(LOG_DIR, exist_ok=True)
353
+ log_file = os.path.join(LOG_DIR, f"shard_{shard_index}_gpu_{gpu_id}.log")
354
+ fh = logging.FileHandler(log_file, mode='w')
355
+ fh.setLevel(logging.INFO)
356
+ fh.setFormatter(log_formatter)
357
+ logger.addHandler(fh)
358
+
359
+ old_factory = logging.getLogRecordFactory()
360
+ def record_factory(*args, **kwargs):
361
+ record = old_factory(*args, **kwargs)
362
+ record.shard_index = shard_index
363
+ return record
364
+ logging.setLogRecordFactory(record_factory)
365
+
366
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
367
+ device = f"cuda:0"
368
+ logger.info(f"Process started for Shard {shard_index} on GPU {gpu_id} (logical device {device})")
369
+ logger.info(f"Arguments: {args}")
370
+
371
+ processed_shard = None
372
+ processing_time = 0
373
+
374
+ try:
375
+ # --- Load the full dataset ---
376
+ logger.info(f"Loading dataset from {DATASET_PATH}")
377
+ try:
378
+ full_ds = load_from_disk(DATASET_PATH)
379
+ # breakpoint()  # debugging leftover; disabled so unattended shard runs do not stall
380
+ logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
381
+ # Check the feature type of question_audio - SHOULD BE string
382
+ if 'question_audio' not in full_ds.features:
383
+ logger.error("Dataset loaded, but required 'question_audio' column is missing!")
384
+ return
385
+ logger.info(f"Feature 'question_audio': {full_ds.features['question_audio']}")
386
+ if not isinstance(full_ds.features['question_audio'], Value) or full_ds.features['question_audio'].dtype != 'string':
387
+ logger.warning(f"'question_audio' column type is not string ({full_ds.features['question_audio']}). Attempting to proceed, but expecting paths.")
388
+
389
+ except Exception as e:
390
+ logger.error(f"Failed to load dataset: {e}", exc_info=True)
391
+ return
392
+
393
+ # --- Preprocessing: check audio file validity (on paths) ---
394
+ logger.info(f"Checking audio file validity (min duration: {MIN_AUDIO_DURATION_MS}ms)...")
395
+ check_features = Features({
396
+ **full_ds.features,
397
+ 'audio_is_valid': Value('bool'),
398
+ 'audio_check_error': Value('string'),
399
+ 'audio_duration_ms': Value('int64')
400
+ })
401
+ num_check_workers = max(1, min(num_check_workers, os.cpu_count()))
402
+ logger.info(f"Using {num_check_workers} workers for audio check.")
403
+ full_ds_checked = full_ds.map(
404
+ check_audio_file_validity,
405
+ num_proc=num_check_workers,
406
+ features=check_features,
407
+ batched=False,
408
+ fn_kwargs={"shard_idx_for_log": shard_index}
409
+ )
410
+ logger.info("Audio validity check complete.")
411
+
412
+ # --- Filter out invalid audio ---
413
+ original_count = len(full_ds_checked)
414
+ valid_audio_ds = full_ds_checked.filter(
415
+ lambda x: x['audio_is_valid'],
416
+ num_proc=num_check_workers,
417
+ load_from_cache_file=False
418
+ )
419
+ filtered_count = original_count - len(valid_audio_ds)
420
+ logger.info(f"Filtered out {filtered_count} samples based on path validity/duration. Kept {len(valid_audio_ds)} samples.")
421
+
422
+ # Log filtering reasons (same as before)
423
+ if filtered_count > 0:
424
+ # Avoid running another potentially slow filter just for logging
425
+ logger.warning("Logging top filtering reasons (based on initial check results, sample limit applies if dataset large)...")
426
+ try:
427
+ error_reasons = collections.Counter(r['audio_check_error'] for r in full_ds_checked.filter(lambda x: not x['audio_is_valid'], load_from_cache_file=False).select(range(min(1000, filtered_count))))
428
+ for reason, count in error_reasons.most_common(10):
429
+ if reason: # Don't log None reasons if any slip through
430
+ logger.warning(f" - {reason}: {count} samples")
431
+ except Exception as log_e:
432
+ logger.warning(f"Could not retrieve filtering reasons: {log_e}")
433
+
434
+
435
+ if len(valid_audio_ds) == 0:
436
+ logger.error("No valid audio samples found after filtering. Exiting.")
437
+ return
438
+
439
+ # --- !! REMOVED cast_column step !! ---
440
+ # The 'question_audio' column remains as paths in valid_audio_ds
441
+
442
+ # --- Select the shard this process should handle ---
443
+ logger.info(f"Creating shard {shard_index} from valid audio data (paths)...")
444
+ ds_shard = valid_audio_ds.shard(num_shards=NUM_SHARDS, index=shard_index, contiguous=True)
445
+ logger.info(f"Shard {shard_index} created with {ds_shard.num_rows} rows.")
446
+ # Log features to confirm 'question_audio' is still string
447
+ logger.info(f"Shard features: {ds_shard.features}")
448
+
449
+
450
+ if ds_shard.num_rows == 0:
451
+ logger.warning(f"Shard {shard_index} is empty after sharding. Saving empty structure and exiting process.")
452
+ # Define empty output features (keeping original path column)
453
+ final_features = Features({
454
+ **ds_shard.features, # Includes original columns like question_audio (path)
455
+ 'asr_transcription': Value('string'),
456
+ 'wer': Value('float32'),
457
+ 'is_bad_tts': Value('bool'),
458
+ 'error_message': Value('string')
459
+ })
460
+ # Remove check columns from features before creating empty table
461
+ final_features.pop('audio_is_valid', None)
462
+ final_features.pop('audio_check_error', None)
463
+ final_features.pop('audio_duration_ms', None)
464
+
465
+ shard_output_path = os.path.join(OUTPUT_DIR, f"shard_{shard_index}")
466
+ os.makedirs(shard_output_path, exist_ok=True)
467
+ try:
468
+ empty_table = pa.Table.from_pydict({}, schema=final_features.arrow_schema)
469
+ empty_ds = Dataset(arrow_table=empty_table)
470
+ empty_ds.save_to_disk(shard_output_path)
471
+ logger.info(f"Saved empty dataset structure for shard {shard_index}.")
472
+ except Exception as save_e:
473
+ logger.error(f"Could not save empty dataset structure for shard {shard_index}: {save_e}")
474
+ processed_shard = empty_ds # Set processed_shard for stats
475
+ return # Exit after handling empty shard
476
+
477
+ # --- Load the ASR pipeline ---
478
+ logger.info(f"Loading ASR pipeline {MODEL_ID} on {device}...")
479
+ try:
480
+ asr_pipeline = pipeline(
481
+ "automatic-speech-recognition",
482
+ model=MODEL_ID,
483
+ torch_dtype=torch.float16,
484
+ device=device,
485
+ batch_size=pipeline_batch_size # Pipeline's internal batch size
486
+ )
487
+ logger.info(f"ASR pipeline loaded successfully with internal batch size {pipeline_batch_size}.")
488
+ except Exception as e:
489
+ logger.error(f"Failed to load ASR pipeline: {e}", exc_info=True)
490
+ return
491
+
492
+ # --- Process the shard data with map ---
493
+ logger.info(f"Starting processing shard {shard_index} with map batch size {map_batch_size} and WER threshold {wer_threshold}...")
494
+ start_time = time.time()
495
+ # Define output features: Keep original columns + add new ones
496
+ output_features = Features({
497
+ **ds_shard.features, # Keep original columns (incl. question_audio path)
498
+ 'asr_transcription': Value('string'),
499
+ 'wer': Value('float32'),
500
+ 'is_bad_tts': Value('bool'),
501
+ 'error_message': Value('string')
502
+ })
503
+ # Remove check columns from output features
504
+ output_features.pop('audio_is_valid', None)
505
+ output_features.pop('audio_check_error', None)
506
+ output_features.pop('audio_duration_ms', None)
507
+
508
+ processed_shard = ds_shard.map(
509
+ check_audio_quality_batch,
510
+ batched=True,
511
+ batch_size=map_batch_size, # map's batch size (rows passed to func)
512
+ fn_kwargs={
513
+ "asr_pipeline": asr_pipeline,
514
+ "wer_threshold": wer_threshold,
515
+ "target_sr": TARGET_SR,
516
+ "shard_idx_for_log": shard_index
517
+ },
518
+ features=output_features, # Define output schema
519
+ load_from_cache_file=False, # Disable caching
520
+ remove_columns=['audio_is_valid', 'audio_check_error', 'audio_duration_ms'] # Remove check columns during map
521
+ )
522
+ end_time = time.time()
523
+ processing_time = end_time - start_time
524
+ logger.info(f"Shard {shard_index} processing finished in {processing_time:.2f} seconds.")
525
+ logger.info(f"Processed shard {shard_index} has columns: {processed_shard.column_names}")
526
+
527
+ # --- Save the processed shard ---
528
+ # No need to remove check columns here, done in map
529
+ shard_output_path = os.path.join(OUTPUT_DIR, f"shard_{shard_index}")
530
+ logger.info(f"Saving processed shard {shard_index} to {shard_output_path}...")
531
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
532
+ try:
533
+ processed_shard.save_to_disk(shard_output_path)
534
+ logger.info(f"Shard {shard_index} saved successfully.")
535
+ except Exception as e:
536
+ # Check specifically for Arrow serialization issues if they occur
537
+ logger.error(f"Failed to save processed shard {shard_index} to {shard_output_path}: {e}", exc_info=True)
538
+ # The IndexError related to soundfile should NOT happen now
539
+
540
+ finally:
541
+ # --- Log statistics ---
542
+ if processed_shard is not None:
543
+ log_shard_statistics(processed_shard, shard_index, wer_threshold, processing_time)
544
+ else:
545
+ logger.warning("Processing did not complete or failed early. No statistics to log.")
546
+
547
+ logger.info(f"Process for Shard {shard_index} on GPU {gpu_id} finished.")
548
+ if fh:
549
+ logger.removeHandler(fh)
550
+ fh.close()
551
+
552
+ if __name__ == "__main__":
553
+ # Add librosa to requirements check potentially
554
+ try:
555
+ import librosa
556
+ except ImportError:
557
+ print("Error: librosa is required. Please install it using: pip install librosa")
558
+ exit(1)
559
+ main()
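Note: the script above handles exactly one shard on one GPU, so a full run needs one process per shard. A minimal launcher sketch follows; the script filename and the shard/GPU count are assumptions, since neither is fixed by the code itself.

# launch_shards.py - hypothetical launcher, not part of this commit
import subprocess

NUM_GPUS = 4                          # assumption: matches NUM_SHARDS in the shard script
SHARD_SCRIPT = "asr_check_shard.py"   # assumption: whatever name the script above is saved under

procs = []
for shard in range(NUM_GPUS):
    # One worker per shard, pinned to its own GPU via --gpu_id
    cmd = [
        "python", SHARD_SCRIPT,
        "--shard_index", str(shard),
        "--gpu_id", str(shard),
        "--wer_threshold", "0.4",
        "--pipeline_batch_size", "8",
    ]
    procs.append(subprocess.Popen(cmd))

# Wait for all shard workers; each writes its own log file and output directory.
exit_codes = [p.wait() for p in procs]
print("Shard exit codes:", exit_codes)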
r1-a/dataset/sciq.py ADDED
@@ -0,0 +1,176 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ from datasets import load_dataset, Dataset
6
+ import sys
7
+ from tqdm import tqdm
8
+
9
+ sys.path.append('/root/autodl-tmp/CosyVoice')
10
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
11
+ from cosyvoice.utils.file_utils import load_wav
12
+
13
+ # ------------------------
14
+ # Configuration
15
+ # ------------------------
16
+ COMMON_VOICE_LANGUAGE = "en"
17
+ DATASET_NAME = "sciq" # target dataset: SciQ
18
+ OUTPUT_DATASET_PATH = './sciq_with_audio'
19
+ SAMPLE_RATE = 16000
20
+
21
+ # ------------------------
22
+ # Helper functions
23
+ # ------------------------
24
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
25
+ """
26
+ Randomly sample one utterance and its transcript from VoxPopuli (standing in for Common Voice) to use as the prompt.
27
+ """
28
+ idx = random.randint(0, len(common_voice_dataset) - 1)
29
+ sample = common_voice_dataset.select([idx])[0]
30
+ audio = sample['audio']
31
+ waveform = torch.tensor(audio['array'], dtype=torch.float32)
32
+ sr = audio['sampling_rate']
33
+ if sr != sample_rate:
34
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
35
+ waveform = resampler(waveform)
36
+ return waveform.unsqueeze(0), sample['raw_text']
37
+
38
+ def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
39
+ """
40
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
41
+ """
42
+ try:
43
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
44
+ # Optional: save prompt.wav for debugging
45
+ # torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE)
46
+
47
+ all_speech = []
48
+ for i, j in enumerate(cosyvoice.inference_zero_shot(
49
+ query_text,
50
+ prompt_text,
51
+ prompt_speech,
52
+ stream=stream,
53
+ text_frontend=False
54
+ )):
55
+ all_speech.append(j['tts_speech'])
56
+
57
+ # Concatenate all generated speech chunks
58
+ combined_speech = torch.cat(all_speech, dim=-1)
59
+ sample_rate_val = cosyvoice.sample_rate
60
+
61
+ return {
62
+ 'audio_tensor': combined_speech,
63
+ 'sample_rate': sample_rate_val
64
+ }
65
+ except Exception as e:
66
+ print(f"Error converting text to audio: {e}")
67
+ return None
68
+
69
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
70
+ """
71
+ Run TTS for a single SciQ example.
72
+ Only sample['question'] is converted here.
73
+ """
74
+ query = example['question'] # change this field if a different text should be synthesized
75
+ audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
76
+ if audio_result is not None:
77
+ return {
78
+ 'audio_tensor': audio_result['audio_tensor'],
79
+ 'sample_rate': audio_result['sample_rate']
80
+ }
81
+ else:
82
+ return None
83
+
84
+ # ------------------------
85
+ # Data loading and model initialization
86
+ # ------------------------
87
+ print("Loading VoxPopuli (as Common Voice) dataset...")
88
+ common_voice = load_dataset("facebook/voxpopuli", "en", split='train')
89
+ print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
90
+
91
+ print("Initializing CosyVoice2 model...")
92
+ cosyvoice = CosyVoice2(
93
+ '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # replace with the actual model path
94
+ load_jit=True,
95
+ load_trt=False,
96
+ fp16=False
97
+ )
98
+
99
+ print("Loading SciQ dataset...")
100
+ dataset = load_dataset("allenai/sciq")
101
+
102
+ # 创建输出目录
103
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
104
+
105
+ # ------------------------
106
+ # Main processing loop
107
+ # ------------------------
108
+ final_dataset_dict = {} # holds the final processed data for each split
109
+
110
+ for split_name, split_dataset in dataset.items():
111
+ print(f"Processing split: {split_name} with {len(split_dataset)} examples")
112
+ split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
113
+ os.makedirs(split_output_dir, exist_ok=True)
114
+
115
+ # Progress file so interrupted runs can resume
116
+ progress_file = os.path.join(split_output_dir, "progress.txt")
117
+ start_index = 0
118
+ if os.path.exists(progress_file):
119
+ try:
120
+ with open(progress_file, "r") as f:
121
+ start_index = int(f.read().strip())
122
+ print(f"Resuming split '{split_name}' from sample index {start_index}")
123
+ except Exception as e:
124
+ print(f"读取进度文件失败:{e}")
125
+
126
+ final_samples = [] # 用于存储处理后数据
127
+
128
+ # 遍历处理每条样本
129
+ for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"):
130
+ # If already processed, skip TTS and just record the existing file's metadata in final_samples
131
+ if i < start_index:
132
+ sample = split_dataset[i]
133
+ wav_path = os.path.join(split_output_dir, f"{i}.wav")
134
+ if os.path.exists(wav_path):
135
+ # 保留所有原始字段 + 音频路径
136
+ sample_dict = {k: sample[k] for k in sample.keys()}
137
+ sample_dict["audio_filepath"] = wav_path
138
+ final_samples.append(sample_dict)
139
+ continue
140
+
141
+ sample = split_dataset[i]
142
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
143
+
144
+ if result is not None:
145
+ audio_tensor = result['audio_tensor']
146
+ if audio_tensor.dim() == 1:
147
+ audio_tensor = audio_tensor.unsqueeze(0)
148
+ sample_rate_val = result['sample_rate']
149
+
150
+ output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
151
+ try:
152
+ torchaudio.save(output_wav_path, audio_tensor, sample_rate_val)
153
+ except Exception as e:
154
+ print(f"Failed to save wav for sample {i}: {e}")
155
+ continue
156
+
157
+ # 保留所有原始字段 + 生成的音频路径
158
+ sample_dict = {k: sample[k] for k in sample.keys()}
159
+ sample_dict["audio_filepath"] = output_wav_path
160
+ final_samples.append(sample_dict)
161
+ else:
162
+ print(f"Sample {i} processing failed, no audio generated.")
163
+
164
+ # 更新进度记录
165
+ with open(progress_file, "w") as f:
166
+ f.write(str(i + 1))
167
+
168
+ # Build a Hugging Face Dataset and write it to disk
169
+ final_dataset_obj = Dataset.from_list(final_samples)
170
+ final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
171
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
172
+
173
+ print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.")
174
+ final_dataset_dict[split_name] = final_dataset_obj
175
+
176
+ print("所有分割处理完毕,最终数据集已保存。")
r1-a/dataset/shp.py ADDED
@@ -0,0 +1,148 @@
1
+ import re
2
+ import os # 确保导入 os 用于保存
3
+ from datasets import load_dataset, Dataset
4
+ from tqdm.auto import tqdm # 用于显示进度条
5
+
6
+ # --- Tunable filter parameters ---
7
+ # (保持不变)
8
+ SCORE_RATIO_THRESHOLD = 2.0
9
+ MIN_CHOSEN_SCORE = 3
10
+ MIN_HISTORY_WORDS = 10
11
+ MAX_HISTORY_WORDS = 150 # 调整为 150
12
+ MAX_URLS = 0 # 调整为 0
13
+ MAX_NEWLINES = 5
14
+ FORBIDDEN_PATTERNS = [
15
+ r"```.*```",
16
+ r"\|.*\|.*\|",
17
+ ]
18
+ MIN_RESPONSE_WORDS = 10
19
+
20
+ # --- Main script logic ---
21
+
22
+ def is_tts_friendly(text):
23
+ """检查文本是否大致适合 TTS"""
24
+ # (保持不变)
25
+ word_count = len(text.split())
26
+ if not (MIN_HISTORY_WORDS <= word_count <= MAX_HISTORY_WORDS):
27
+ return False
28
+ if text.count('http') > MAX_URLS: # 使用调整后的 MAX_URLS
29
+ return False
30
+ if text.count('\n') > MAX_NEWLINES:
31
+ return False
32
+ for pattern in FORBIDDEN_PATTERNS:
33
+ if re.search(pattern, text, re.DOTALL):
34
+ return False
35
+ return True
36
+
37
+ def filter_shp2_train_dataset(dataset_name="stanfordnlp/shp-2"): # 函数名稍作修改以反映其目的
38
+ """
39
+ Load and filter the 'train' split of the SHP-2 dataset,
40
+ returning high-quality, TTS-friendly preference pairs.
41
+ """
42
+ split_to_process = 'train' # <--- only the 'train' split is processed
43
+ print(f"加载数据集: {dataset_name}, split: {split_to_process}...")
44
+
45
+ try:
46
+ # --- 修改点 1: 直接加载指定的 split ---
47
+ train_dataset = load_dataset(dataset_name, split=split_to_process)
48
+ print(f"'{split_to_process}' split 加载完成。")
49
+ except Exception as e:
50
+ print(f"错误:无法加载数据集 {dataset_name} 的 '{split_to_process}' split。")
51
+ print(f"错误详情: {e}")
52
+ return [] # 返回空列表表示失败
53
+
54
+ filtered_data = []
55
+ seen_histories = set() # 用于跟踪已经添加的 history,确保唯一性
56
+
57
+ print(f"\n开始处理 '{split_to_process}' split...")
58
+ # --- 修改点 2: 直接迭代加载的 train_dataset ---
59
+ for example in tqdm(train_dataset, desc=f"Filtering {split_to_process} split"):
60
+ history = example.get("history")
61
+ human_ref_A = example.get("human_ref_A")
62
+ human_ref_B = example.get("human_ref_B")
63
+ labels = example.get("labels")
64
+ score_A = example.get("score_A")
65
+ score_B = example.get("score_B")
66
+ score_ratio = example.get("score_ratio")
67
+ domain = example.get("domain")
68
+
69
+ # 基本检查 (保持不变)
70
+ if not all([history, human_ref_A, human_ref_B, labels is not None,
71
+ score_A is not None, score_B is not None, score_ratio is not None, domain]):
72
+ continue
73
+
74
+ # 确保 history 未被处理过 (保持不变)
75
+ if history in seen_histories:
76
+ continue
77
+
78
+ # 确定 chosen 和 reject (保持不变)
79
+ try:
80
+ label_int = int(labels)
81
+ if label_int == 1:
82
+ chosen = human_ref_A
83
+ reject = human_ref_B
84
+ chosen_score = score_A
85
+ elif label_int == 0:
86
+ chosen = human_ref_B
87
+ reject = human_ref_A
88
+ chosen_score = score_B
89
+ else:
90
+ continue
91
+ except (ValueError, TypeError):
92
+ continue
93
+
94
+ # --- Apply the filter conditions ---
95
+ if score_ratio is None or not isinstance(score_ratio, (int, float)) or score_ratio < SCORE_RATIO_THRESHOLD:
96
+ continue
97
+ if chosen_score is None or not isinstance(chosen_score, (int, float)) or chosen_score < MIN_CHOSEN_SCORE:
98
+ continue
99
+ if not is_tts_friendly(history):
100
+ continue
101
+ if len(chosen.split()) < MIN_RESPONSE_WORDS or len(reject.split()) < MIN_RESPONSE_WORDS:
102
+ continue
103
+
104
+ # --- All filter conditions passed ---
105
+ filtered_data.append({
106
+ "query": history,
107
+ "chosen": chosen,
108
+ "reject": reject,
109
+ "domain": domain,
110
+ })
111
+ seen_histories.add(history)
112
+
113
+ print(f"\n过滤完成。从 '{split_to_process}' split 中总共筛选出 {len(filtered_data)} 条高质量样本。")
114
+ return filtered_data
115
+
116
+ # --- Main program ---
117
+ if __name__ == "__main__":
118
+ # 执行过滤 (调用修改后的函数)
119
+ filtered_examples = filter_shp2_train_dataset()
120
+
121
+ if filtered_examples:
122
+ # 将过滤后的数据转换为 Hugging Face Dataset 对象 (保持不变)
123
+ filtered_dataset = Dataset.from_list(filtered_examples)
124
+
125
+ # 保存过滤后的数据集 (保持不变)
126
+ output_path = "./shp2_filtered_tts_high_quality_train_only" # 修改输出路径以反映内容
127
+ print(f"正在保存过滤后的训练集数据到: {output_path}")
128
+ # 确保输出目录存在
129
+ os.makedirs(output_path, exist_ok=True) # output_path itself is the target directory; dirname() would only return "." here
130
+ filtered_dataset.save_to_disk(output_path)
131
+ print("数据集保存完成。")
132
+
133
+ # 打印一些样本看看 (保持不变)
134
+ print("\n部分样本预览:")
135
+ # 从保存的 Dataset 加载并预览,确保保存成功
136
+ try:
137
+ loaded_dataset = Dataset.load_from_disk(output_path)
138
+ for i in range(min(5, len(loaded_dataset))):
139
+ sample = loaded_dataset[i]
140
+ print(f"--- 样本 {i+1} ---")
141
+ print(f"Domain: {sample['domain']}")
142
+ print(f"Query: {sample['query'][:200]}...")
143
+ print(f"Chosen: {sample['chosen'][:200]}...")
144
+ except Exception as e:
145
+ print(f"加载预览样本时出错: {e}") # 增加错误处理
146
+
147
+ else:
148
+ print("没有找到符合条件的样本,请检查过滤参数设置或确认 'train' split 是否存在且包含数据。")
r1-a/dataset/shp_tts.py ADDED
@@ -0,0 +1,494 @@
1
+ # --- SET CUDA DEVICE ---
2
+ # Method 1: Set environment variable BEFORE importing torch/cosyvoice
3
+ # This makes only GPU 0 visible to the script, appearing as 'cuda:0' internally.
4
+ import os
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
6
+ # --- End CUDA Device Setting ---
7
+
8
+ import random
9
+ import torch
10
+ import torchaudio
11
+ # Make sure necessary types are imported
12
+ from datasets import load_dataset, Dataset, load_from_disk, Features, Value
13
+ import sys
14
+ from tqdm import tqdm
15
+ import time
16
+ import shutil # Added for potentially removing old dataset save dirs
17
+
18
+ # Check if the specified GPU is available after setting the environment variable
19
+ if not torch.cuda.is_available():
20
+ print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. Check your PyTorch installation, GPU drivers, and that GPU 1 exists and is functional.")
21
+ # Force exit if the intended GPU is not found
22
+ sys.exit(1)
23
+ else:
24
+ # Since CUDA_VISIBLE_DEVICES is set to '0', the first *visible* device is cuda:0
25
+ effective_device = torch.device("cuda:0")
26
+ try:
27
+ print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1
28
+ print(f"Script will effectively run on: {effective_device}")
29
+ # Perform a small check to ensure the device is usable
30
+ _ = torch.tensor([1.0]).to(effective_device)
31
+ print("Device check successful.")
32
+ except Exception as e:
33
+ print(f"ERROR: Failed CUDA device check for visible device 'cuda:0' (original GPU 1): {e}")
34
+ sys.exit(1)
35
+
36
+
37
+ # Ensure CosyVoice path is correct
38
+ COSYVOICE_PATH = '/home/chenyifu/CosyVoice' # Make sure this path is correct
39
+ if not os.path.isdir(COSYVOICE_PATH):
40
+ print(f"ERROR: CosyVoice path not found: {COSYVOICE_PATH}")
41
+ sys.exit(1)
42
+ sys.path.append(COSYVOICE_PATH)
43
+
44
+ # Import CosyVoice *after* setting the environment variable
45
+ try:
46
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
47
+ from cosyvoice.utils.file_utils import load_wav
48
+ print("CosyVoice imported successfully.")
49
+ except ImportError as e:
50
+ print(f"Error importing CosyVoice: {e}")
51
+ print(f"Please ensure the path '{COSYVOICE_PATH}' is correct and the library is installed within that directory.")
52
+ sys.exit(1)
53
+ except Exception as e:
54
+ print(f"An unexpected error occurred during CosyVoice import: {e}")
55
+ sys.exit(1)
56
+
57
+ # ------------------------
58
+ # Configuration (MODIFIED FOR NEW DATASET)
59
+ # ------------------------
60
+ COMMON_VOICE_LANGUAGE = "en" # Language for prompts
61
+
62
+ # --- !! MODIFIED !! ---
63
+ # Input: Path to the dataset created by the previous selection script
64
+ INPUT_DATASET_PATH = "./shp2_final_top20_percent/train_split_top20_percent_by_complexity"
65
+ # Output: Directory to save new audio files and the final dataset object
66
+ OUTPUT_DATASET_PATH = './shp2_top20_percent_with_query_audio'
67
+ # --- End MODIFIED ---
68
+
69
+ SAMPLE_RATE = 16000 # Target sample rate for TTS output (should match CosyVoice default)
70
+ MAX_TTS_RETRIES = 3
71
+ RETRY_DELAY_SECONDS = 3 # Slightly increased delay
72
+
73
+ # ------------------------
74
+ # Helper functions (GPU handling and core TTS logic - UNCHANGED as requested)
75
+ # ------------------------
76
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
77
+ """
78
+ Randomly sample one utterance and its transcript from VoxPopuli to use as the prompt.
79
+ (Logic remains unchanged)
80
+ """
81
+ idx = random.randint(0, len(common_voice_dataset) - 1)
82
+ try:
83
+ # Use select().with_format('numpy') for potentially better memory handling with large datasets
84
+ sample = common_voice_dataset.select([idx]).with_format('numpy')[0]
85
+ audio = sample['audio']
86
+ waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU
87
+ sr = audio['sampling_rate']
88
+
89
+ if sr != sample_rate:
90
+ # Ensure waveform is 1D before resampling
91
+ if waveform.dim() > 1:
92
+ waveform = waveform.mean(dim=0)
93
+ if waveform.dim() != 1:
94
+ print(f"Warning: Unexpected waveform dimension {waveform.dim()} before resampling. Skipping prompt.")
95
+ return get_random_prompt(common_voice_dataset, sample_rate) # Retry
96
+
97
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
98
+ waveform = resampler(waveform)
99
+
100
+ # Ensure output is 2D [1, T]
101
+ if waveform.dim() == 1:
102
+ waveform = waveform.unsqueeze(0)
103
+ elif waveform.dim() > 2:
104
+ print(f"Warning: Unexpected waveform dimension {waveform.dim()} after resampling. Skipping prompt.")
105
+ return get_random_prompt(common_voice_dataset, sample_rate) # Retry
106
+
107
+ raw_text = sample.get('raw_text', '')
108
+ if waveform.numel() == 0 or not raw_text or not raw_text.strip():
109
+ # print("Warning: Got an empty audio or text prompt, trying again...")
110
+ return get_random_prompt(common_voice_dataset, sample_rate) # Retry
111
+
112
+ # Return CPU tensor, CosyVoice inference should handle moving it
113
+ return waveform, raw_text
114
+ except Exception as e:
115
+ print(f"Error getting random prompt at index {idx}: {e}. Retrying...")
116
+ time.sleep(0.1) # Small delay before retry
117
+ return get_random_prompt(common_voice_dataset, sample_rate)
118
+
119
+ def text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
120
+ """
121
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
122
+ Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
123
+ (Logic remains unchanged)
124
+ """
125
+ last_exception = None
126
+ prompt_speech = None
127
+ prompt_text = "N/A"
128
+
129
+ for attempt in range(max_retries):
130
+ try:
131
+ # Get prompt - ensures it's valid this time
132
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
133
+ # prompt_speech is initially on CPU
134
+
135
+ all_speech = []
136
+ # cosyvoice.inference_zero_shot should internally use the GPU device it was initialized on
137
+ # (which should be the visible cuda:0, i.e. the physical GPU selected by CUDA_VISIBLE_DEVICES above)
138
+ inference_generator = cosyvoice.inference_zero_shot(
139
+ text_to_convert,
140
+ prompt_text,
141
+ prompt_speech, # Pass CPU tensor
142
+ stream=stream,
143
+ text_frontend=False # Assuming default frontend is desired
144
+ )
145
+ # Generated chunks 'tts_speech' will be on the GPU
146
+ for i, chunk in enumerate(inference_generator):
147
+ if chunk is None:
148
+ print(f"Warning: Received None chunk {i} during TTS generation for text '{text_to_convert[:60]}...'")
149
+ continue
150
+ if 'tts_speech' in chunk and chunk['tts_speech'] is not None and chunk['tts_speech'].numel() > 0:
151
+ # Ensure chunk is on the correct device (should be already, but belt-and-suspenders)
152
+ gpu_chunk = chunk['tts_speech'].to(effective_device)
153
+ all_speech.append(gpu_chunk)
154
+ # else: # Reduce log noise
155
+ # print(f"Debug: Chunk {i} missing 'tts_speech' or is empty for text '{text_to_convert[:60]}...'")
156
+
157
+
158
+ if not all_speech:
159
+ # Clear GPU memory cache if an error occurs during generation
160
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
161
+ raise ValueError("TTS inference finished but produced no valid audio chunks.")
162
+
163
+ # combined_speech is on GPU
164
+ combined_speech = torch.cat(all_speech, dim=-1)
165
+ sample_rate_val = cosyvoice.sample_rate
166
+
167
+ # --- Add a check for silence ---
168
+ # Check max absolute amplitude; threshold might need tuning
169
+ if torch.max(torch.abs(combined_speech)) < 0.001:
170
+ print(f"Warning: Generated audio appears to be silent for text '{text_to_convert[:60]}...'. Retrying...")
171
+ raise ValueError("Generated audio is silent")
172
+
173
+
174
+ return {
175
+ # Return GPU tensor, will be moved to CPU before saving
176
+ 'audio_tensor': combined_speech,
177
+ 'sample_rate': sample_rate_val
178
+ }
179
+
180
+ except Exception as e:
181
+ last_exception = e
182
+ print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
183
+ print(f" Text: '{text_to_convert[:100]}...'")
184
+ print(f" Prompt Text Used: '{prompt_text[:100]}...'")
185
+ # Clear GPU cache on error
186
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
187
+ if attempt < max_retries - 1:
188
+ print(f" Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
189
+ time.sleep(RETRY_DELAY_SECONDS)
190
+ else:
191
+ print(f" All {max_retries} TTS attempts failed.")
192
+
193
+ print(f"Failed to generate audio for text after {max_retries} attempts: '{text_to_convert[:60]}...'")
194
+ if last_exception:
195
+ print(f"Last error: {last_exception}")
196
+ # Explicitly return None on failure
197
+ return None
198
+
199
+ # --- !! MODIFIED process_example !! ---
200
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
201
+ """
202
+ Run TTS for a single example of the *SHP-2 Top 20%* dataset loaded from disk.
203
+ Processes the example['query'] field.
204
+ """
205
+ # --- MODIFIED: Target the 'query' field ---
206
+ text_to_convert = example.get('query')
207
+ # --- End MODIFIED ---
208
+
209
+ if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "":
210
+ # --- MODIFIED: Update warning message ---
211
+ print(f"Warning: Skipping example due to missing or empty 'query' field. Keys: {list(example.keys())}")
212
+ # --- End MODIFIED ---
213
+ return None
214
+
215
+ # Call the unchanged text_to_audio function
216
+ audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)
217
+
218
+ if audio_result is not None:
219
+ audio_tensor = audio_result['audio_tensor'] # Still on GPU here
220
+ # Basic validation of the tensor
221
+ if audio_tensor is None or audio_tensor.numel() == 0:
222
+ print(f"Warning: TTS process returned empty tensor for query: '{text_to_convert[:60]}...'")
223
+ return None
224
+
225
+ # Ensure correct shape (should be [1, T] from text_to_audio)
226
+ if audio_tensor.dim() == 1:
227
+ audio_tensor = audio_tensor.unsqueeze(0)
228
+ elif audio_tensor.dim() > 2:
229
+ print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
230
+ audio_tensor = audio_tensor.view(1, -1) # Flatten to [1, T]
231
+
232
+ # Double-check for emptiness after potential reshape
233
+ if audio_tensor.numel() == 0:
234
+ print(f"Warning: Generated audio tensor became empty after reshape for query: '{text_to_convert[:60]}...'")
235
+ return None
236
+
237
+ return {
238
+ 'audio_tensor': audio_tensor, # Return GPU tensor
239
+ 'sample_rate': audio_result['sample_rate']
240
+ }
241
+ else:
242
+ # text_to_audio already prints detailed errors
243
+ return None
244
+
245
+ # ------------------------
246
+ # Data loading and model initialization
247
+ # ------------------------
248
+ print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
249
+ try:
250
+ # Load prompt dataset to CPU memory
251
+ common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train', trust_remote_code=True)
252
+ # Filter potentially problematic samples (optional, but can help)
253
+ common_voice = common_voice.filter(lambda x: x['audio'] is not None and x['audio']['array'] is not None and x['raw_text'] is not None and len(x['raw_text'].strip()) > 5 and len(x['audio']['array']) > SAMPLE_RATE * 0.5) # Min 0.5 sec, non-empty text
254
+ print(f"Loaded and filtered VoxPopuli '{COMMON_VOICE_LANGUAGE}' samples: {len(common_voice)}")
255
+ if len(common_voice) == 0:
256
+ raise ValueError(f"VoxPopuli dataset '{COMMON_VOICE_LANGUAGE}' loaded but contains no valid samples after filtering.")
257
+ except Exception as e:
258
+ print(f"Error loading or filtering VoxPopuli dataset: {e}")
259
+ sys.exit(1)
260
+
261
+
262
+ print("Initializing CosyVoice2 model...")
263
+ try:
264
+ # CosyVoice should automatically initialize on the visible device ('cuda:0', which is original 'cuda:1')
265
+ cosyvoice_model_path = os.path.join(COSYVOICE_PATH, 'pretrained_models/CosyVoice2-0.5B')
266
+ if not os.path.isdir(cosyvoice_model_path):
267
+ print(f"ERROR: CosyVoice pretrained model directory not found: {cosyvoice_model_path}")
268
+ sys.exit(1)
269
+
270
+ cosyvoice = CosyVoice2(
271
+ cosyvoice_model_path,
272
+ load_jit=True, # Assuming JIT is preferred
273
+ load_trt=False, # Ensure TRT is False if not set up for GPU 1
274
+ fp16=False # Keep FP16 False unless GPU 1 is known to handle it well and has enough VRAM
275
+ # device=effective_device # Usually not needed if CUDA_VISIBLE_DEVICES is set
276
+ )
277
+ print(f"CosyVoice model initialized. Target device: {effective_device}")
278
+ # Verify model is on the correct device (optional check)
279
+ if hasattr(cosyvoice, 'model') and hasattr(cosyvoice.model, 'device'):
280
+ print(f"CosyVoice internal model device: {cosyvoice.model.device}")
281
+ elif hasattr(cosyvoice, 'device'):
282
+ print(f"CosyVoice main object device: {cosyvoice.device}")
283
+
284
+ except Exception as e:
285
+ print(f"Error initializing CosyVoice2 model: {e}")
286
+ if isinstance(e, RuntimeError) and 'CUDA' in str(e):
287
+ print("This might be a CUDA initialization error. Ensure GPU 1 is functional, has enough memory, and required CUDA toolkit versions are compatible.")
288
+ sys.exit(1)
289
+
290
+ # --- !! MODIFIED Dataset Loading !! ---
291
+ print(f"\nLoading the target dataset from disk: {INPUT_DATASET_PATH}")
292
+ if not os.path.exists(INPUT_DATASET_PATH):
293
+ print(f"Error: Input dataset directory not found at '{INPUT_DATASET_PATH}'.")
294
+ print("Please ensure the previous (selection) script ran successfully and produced the dataset at this location.")
295
+ sys.exit(1)
296
+
297
+ try:
298
+ input_dataset = load_from_disk(INPUT_DATASET_PATH)
299
+
300
+ print(f"Successfully loaded dataset with {len(input_dataset)} examples.")
301
+ if len(input_dataset) == 0:
302
+ print("Error: The loaded dataset is empty. Cannot proceed.")
303
+ sys.exit(1)
304
+ # Store original features to reconstruct the final dataset correctly
305
+ original_features = input_dataset.features
306
+ print(f"Original features: {original_features}")
307
+ # Check for 'query' column existence
308
+ if 'query' not in original_features:
309
+ print(f"Error: The loaded dataset from '{INPUT_DATASET_PATH}' does not contain the required 'query' column.")
310
+ sys.exit(1)
311
+
312
+ except Exception as e:
313
+ print(f"Error loading dataset from '{INPUT_DATASET_PATH}': {e}")
314
+ sys.exit(1)
315
+ # --- End MODIFIED Dataset Loading ---
316
+
317
+
318
+ # --- Create output directories ---
319
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
320
+ # Subdirectory for the actual audio files
321
+ audio_output_dir = os.path.join(OUTPUT_DATASET_PATH, "audio_files")
322
+ os.makedirs(audio_output_dir, exist_ok=True)
323
+ print(f"Audio files will be saved in: {audio_output_dir}")
324
+ # Path for the progress tracking file
325
+ progress_file = os.path.join(OUTPUT_DATASET_PATH, "progress.txt")
326
+ print(f"Progress will be tracked in: {progress_file}")
327
+
328
+
329
+ # ------------------------
330
+ # Main processing loop (MODIFIED FOR SINGLE DATASET)
331
+ # ------------------------
332
+ print(f"\nStarting TTS processing for {len(input_dataset)} samples...")
333
+
334
+ start_index = 0
335
+ # Read progress file to resume if necessary
336
+ if os.path.exists(progress_file):
337
+ try:
338
+ with open(progress_file, "r") as f:
339
+ content = f.read().strip()
340
+ if content:
341
+ start_index = int(content)
342
+ print(f"Resuming TTS processing from sample index {start_index}")
343
+ else:
344
+ print(f"Progress file '{progress_file}' is empty, starting TTS from index 0.")
345
+ start_index = 0
346
+ except ValueError:
347
+ print(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.")
348
+ start_index = 0
349
+ except Exception as e:
350
+ print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
351
+ start_index = 0
352
+
353
+ # List to hold dictionaries for the final dataset
354
+ final_samples = []
355
+
356
+ # --- Main Loop ---
357
+ pbar = tqdm(range(start_index, len(input_dataset)), desc=f"TTS on 'query' field", initial=start_index, total=len(input_dataset))
358
+ for i in pbar:
359
+ sample = input_dataset[i] # Get sample dictionary (on CPU)
360
+
361
+ # Define unique output WAV path using the index
362
+ # Using index is simple, assumes dataset order is stable during processing
363
+ output_wav_filename = f"query_{i}.wav"
364
+ output_wav_path = os.path.join(audio_output_dir, output_wav_filename)
365
+
366
+ # --- Check if audio file already exists ---
367
+ if os.path.exists(output_wav_path):
368
+ # If already processed, create the dict for the final dataset
369
+ sample_dict = dict(sample) # Copy original data
370
+ sample_dict["query_audio_filepath"] = output_wav_path # Add the path field
371
+ final_samples.append(sample_dict)
372
+ # Update progress file even when skipping (to ensure it reflects the latest processed/checked index)
373
+ with open(progress_file, "w") as f:
374
+ f.write(str(i + 1))
375
+ continue # Skip TTS for this sample
376
+
377
+ # --- Perform TTS on the target device ---
378
+ # process_example handles getting the 'query' text and calling text_to_audio
379
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
380
+
381
+ if result is not None and 'audio_tensor' in result and result['audio_tensor'] is not None:
382
+ audio_tensor = result['audio_tensor'] # Received tensor is on GPU
383
+ sample_rate_val = result['sample_rate']
384
+
385
+ try:
386
+ # --- Move tensor to CPU before saving ---
387
+ audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
388
+
389
+ # Ensure shape is 2D [1, T] for torchaudio.save
390
+ if audio_tensor_save.dim() == 1:
391
+ audio_tensor_save = audio_tensor_save.unsqueeze(0)
392
+ elif audio_tensor_save.dim() > 2:
393
+ print(f"Warning: Flattening unexpected tensor shape {audio_tensor_save.shape} before saving.")
394
+ audio_tensor_save = audio_tensor_save.view(1, -1)
395
+
396
+ # Save the audio file
397
+ torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
398
+
399
+ # Create dict for the final dataset including the new path
400
+ sample_dict = dict(sample) # Copy original data
401
+ sample_dict["query_audio_filepath"] = output_wav_path # Add the path field
402
+ final_samples.append(sample_dict)
403
+
404
+ # --- Explicitly delete GPU tensor ---
405
+ del audio_tensor
406
+ # No need to delete audio_tensor_save as it's on CPU
407
+
408
+ except Exception as e:
409
+ print(f"Failed to save wav for sample {i} ('query' field TTS) at {output_wav_path}: {e}")
410
+ # Attempt to remove partially saved/corrupted file if save failed
411
+ if os.path.exists(output_wav_path):
412
+ try: os.remove(output_wav_path)
413
+ except OSError: pass
414
+ # Clear cache on save error too
415
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
416
+ else:
417
+ # Log failure (process_example or text_to_audio already logged details)
418
+ query_text = sample.get('query', 'N/A')
419
+ print(f"Sample {i} TTS failed or produced no audio after retries (Query Text: '{query_text[:60]}...'). Audio file not saved.")
420
+ # Ensure cache is cleared even on TTS failure
421
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
422
+
423
+
424
+ # --- Update progress file ---
425
+ # Write the index of the *next* sample to start from if resuming
426
+ with open(progress_file, "w") as f:
427
+ f.write(str(i + 1))
428
+
429
+ # --- Optional: Periodic cache clearing ---
430
+ if i > 0 and i % 50 == 0: # Example: clear cache every 50 iterations (adjust as needed)
431
+ if torch.cuda.is_available():
432
+ # print(f"Clearing CUDA cache at iteration {i}...") # Debug log
433
+ torch.cuda.empty_cache()
434
+
435
+
436
+ # --- Final cache clear after finishing the loop ---
437
+ if torch.cuda.is_available():
438
+ print("Clearing final CUDA cache...")
439
+ torch.cuda.empty_cache()
440
+
441
+ # ------------------------
442
+ # Save the final dataset (MODIFIED)
443
+ # ------------------------
444
+ print("\nTTS processing loop finished.")
445
+ if final_samples:
446
+ print(f"Successfully processed (or skipped existing) {len(final_samples)} samples.")
447
+
448
+ # --- Define features for the new dataset ---
449
+ # Start with original features and add the new audio path column
450
+ new_features_dict = original_features.copy()
451
+ new_column_name = 'query_audio_filepath'
452
+ if new_column_name in new_features_dict:
453
+ print(f"Warning: Feature '{new_column_name}' already exists in original features. Overwriting.")
454
+ new_features_dict[new_column_name] = Value('string') # Add the new column definition
455
+ try:
456
+ new_features = Features(new_features_dict)
457
+ print(f"Defined new features for saving: {new_features}")
458
+
459
+ # --- Create the final Dataset object ---
460
+ print("Creating final Dataset object from processed samples...")
461
+ final_dataset_obj = Dataset.from_list(final_samples, features=new_features)
462
+
463
+ # --- Define path to save the final dataset metadata object ---
464
+ # This object contains the original data + the new filepath column
465
+ final_dataset_save_path = os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio")
466
+ print(f"Saving final dataset metadata (with audio paths) to: {final_dataset_save_path}...")
467
+
468
+ # Ensure the target directory exists and is empty before saving
469
+ if os.path.exists(final_dataset_save_path):
470
+ print(f"Removing existing directory before saving: {final_dataset_save_path}")
471
+ shutil.rmtree(final_dataset_save_path)
472
+ # The save_to_disk function will create the directory
473
+ # os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True) # Not needed if saving to the dir itself
474
+
475
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
476
+ print(f"Final dataset object saved successfully.")
477
+
478
+ except Exception as e:
479
+ print(f"\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
480
+ print(f"Error during final dataset creation or saving: {e}")
481
+ print(f"Audio files might be saved in '{audio_output_dir}', but the final dataset object could not be created/saved.")
482
+ print(f"Check the features and the content of 'final_samples'.")
483
+ print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
484
+
485
+ else:
486
+ print("Processing finished, but no samples were successfully processed or had existing audio files.")
487
+ print(f"Check logs for TTS errors. Audio files directory: '{audio_output_dir}'.")
488
+
489
+
490
+ print("\n" + "="*60)
491
+ print(f"Script finished.")
492
+ print(f"Generated audio files are located in: '{audio_output_dir}'")
493
+ print(f"The final dataset (metadata + audio file paths) is saved at: '{os.path.join(OUTPUT_DATASET_PATH, 'processed_dataset_with_audio')}' (if saving was successful)")
494
+ print("="*60)
r1-a/dataset/ultrachat.py ADDED
@@ -0,0 +1,261 @@
1
+ import re
2
+ import os
3
+ from datasets import load_dataset, Dataset
4
+ from tqdm.auto import tqdm
5
+ import json
6
+ import string # used for character checks
7
+
8
+ # --- Tunable filter parameters ---
9
+ MIN_USER_QUERY_WORDS = 5
10
+ MAX_USER_QUERY_WORDS = 150
11
+ SIMPLE_PROMPT_PATTERNS = [
12
+ r"^\s*(ok|yes|no|thanks?|got it|great|cool|sounds good|perfect|alright|fine|bye|goodbye)[.!\s]*$",
13
+ r"^\s*\?+\s*$",
14
+ r"^\s*i see\.?\s*$",
15
+ r"^\s*you'?re welcome\.?\s*$", # 增加一些简单回应
16
+ r"^\s*okay then\.?\s*$",
17
+ ]
18
+ CORRUPTED_ENDINGS = [" user", " assistan"]
19
+ MAX_QUERY_URLS = 0
20
+ MAX_QUERY_NEWLINES = 3
21
+ MIN_DIALOGUE_TURNS = 2 # 对 messages 列表的长度要求
22
+
23
+ # --- New: filters for code and other TTS-unfriendly content ---
24
+ FILTER_CODE_KEYWORDS = True # whether to filter queries containing common code keywords
25
+ CODE_KEYWORDS_PATTERN = r"\b(def|class|import|function|const|let|var|public|private|static|void|main|int|float|str|bool|return|yield|async|await|try|except|finally|if|else|for|while|switch|case|break|continue|lambda|map|filter|reduce|numpy|pandas|torch|tensorflow|react|angular|vue|console\.log|System\.out\.println)\b" # 常见编程关键字 (可扩展)
26
+
27
+ FILTER_INLINE_CODE = True # whether to filter queries containing Markdown inline code `...`
28
+ INLINE_CODE_PATTERN = r"`[^`]+`"
29
+
30
+ FILTER_MARKDOWN_TABLE_SEP = True # whether to filter queries containing the Markdown table separator `|---|`
31
+ MARKDOWN_TABLE_SEP_PATTERN = r"\|-+\|"
32
+
33
+ FILTER_EXCESSIVE_SPECIAL_CHARS = True # whether to filter queries with too high a ratio of special characters
34
+ MAX_SPECIAL_CHAR_RATIO = 0.25 # maximum allowed ratio of special characters (non-alphanumeric, non-space)
35
+
36
+ FILTER_LONG_STRINGS_NO_SPACE = True # whether to filter queries containing overly long strings with no spaces
37
+ MAX_NO_SPACE_STRING_LEN = 50 # maximum allowed length of a string without spaces
38
+
39
+ QUERY_FORBIDDEN_PATTERNS = [
40
+ r"```", # 代码块标记 (已有)
41
+ # r"\|.*\|.*\|", # 简单的表格行检测 (可能过于宽泛,用下面的分隔符检测可能更好)
42
+ # 新增模式会根据上面的开关动态添加
43
+ ]
44
+
45
+ # --- Main script logic ---
46
+
47
+ def is_potentially_garbled(text):
48
+ if not text or not isinstance(text, str): return True
49
+ for ending in CORRUPTED_ENDINGS:
50
+ if text.endswith(ending): return True
51
+ # 稍微放宽括号检查,只检查严重不平衡的情况
52
+ if text.count('{') > text.count('}') + 2 or text.count('[') > text.count(']') + 2: return True
53
+ if text.count('```') % 2 != 0: return True # 未闭合的代码块
54
+ return False
55
+
56
+ def is_prompt_suitable(text, turn_index): # turn_index added for debugging
57
+ """检查用户提问是否符合质量和 TTS 要求"""
58
+ if not text or not isinstance(text, str):
59
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Not text or empty")
60
+ return False
61
+
62
+ # --- 基本检查 ---
63
+ word_count = len(text.split())
64
+ if not (MIN_USER_QUERY_WORDS <= word_count <= MAX_USER_QUERY_WORDS):
65
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Word count {word_count} out of range [{MIN_USER_QUERY_WORDS}, {MAX_USER_QUERY_WORDS}]")
66
+ return False
67
+
68
+ text_stripped = text.strip()
69
+ for pattern in SIMPLE_PROMPT_PATTERNS:
70
+ if re.fullmatch(pattern, text_stripped, re.IGNORECASE):
71
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Matched simple pattern '{pattern}'")
72
+ return False
73
+
74
+ if text.count('http') > MAX_QUERY_URLS:
75
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Too many URLs")
76
+ return False
77
+ if text.count('\n') > MAX_QUERY_NEWLINES:
78
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Too many newlines")
79
+ return False
80
+
81
+ # --- 通用禁止模式检查 (包括原有的和动态添加的) ---
82
+ current_forbidden_patterns = list(QUERY_FORBIDDEN_PATTERNS) # 复制基础列表
83
+ if FILTER_MARKDOWN_TABLE_SEP:
84
+ current_forbidden_patterns.append(MARKDOWN_TABLE_SEP_PATTERN)
85
+
86
+ for pattern in current_forbidden_patterns:
87
+ # 使用 re.DOTALL 使 . 匹配换行符, re.IGNORECASE 对某些模式可能有用 (比如关键词)
88
+ search_flags = re.DOTALL
89
+ if pattern == CODE_KEYWORDS_PATTERN: # 关键词需要忽略大小写
90
+ search_flags |= re.IGNORECASE
91
+ if re.search(pattern, text, search_flags):
92
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Matched forbidden pattern '{pattern}'")
93
+ return False
94
+
95
+ # --- 新增:特定代码和 TTS 不友好内容的检查 ---
96
+
97
+ # 1. 检查常见代码关键字
98
+ if FILTER_CODE_KEYWORDS and re.search(CODE_KEYWORDS_PATTERN, text, re.IGNORECASE):
99
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains code keywords")
100
+ return False
101
+
102
+ # 2. 检查 Markdown 行内代码
103
+ if FILTER_INLINE_CODE and re.search(INLINE_CODE_PATTERN, text):
104
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains inline code")
105
+ return False
106
+
107
+ # 3. 检查过长的无空格字符串 (可能为哈希、base64、代码片段等)
108
+ if FILTER_LONG_STRINGS_NO_SPACE:
109
+ # \S 匹配任何非空白字符
110
+ if re.search(r"\S{" + str(MAX_NO_SPACE_STRING_LEN) + r",}", text):
111
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains long string without spaces (>{MAX_NO_SPACE_STRING_LEN})")
112
+ return False
113
+
114
+ # 4. 检查特殊字符(非字母、数字、空格)的比例
115
+ if FILTER_EXCESSIVE_SPECIAL_CHARS and len(text) > 0: # 避免除以零
116
+ special_chars = 0
117
+ total_chars = len(text)
118
+ for char in text:
119
+ # string.punctuation 包含常用标点
120
+ # 我们也排除字母、数字和空格,剩下的算作特殊字符
121
+ if not char.isalnum() and not char.isspace():
122
+ special_chars += 1
123
+ ratio = special_chars / total_chars
124
+ if ratio > MAX_SPECIAL_CHAR_RATIO:
125
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Excessive special characters ratio ({ratio:.2f} > {MAX_SPECIAL_CHAR_RATIO})")
126
+ return False
127
+
128
+ # --- 最终 Garbled 检查 ---
129
+ if is_potentially_garbled(text):
130
+ # print(f"DEBUG: Prompt rejected (turn {turn_index}): Potentially garbled")
131
+ return False
132
+
133
+ # print(f"DEBUG: Prompt accepted (turn {turn_index})")
134
+ return True
135
+
136
+ # --- format_history 函数保持不变 ---
137
+ def format_history(history_list):
138
+ """将历史消息列表格式化为文本"""
139
+ if not history_list:
140
+ return ""
141
+ formatted = []
142
+ for msg in history_list:
143
+ role_tag = "[USER]" if msg.get('role') == 'user' else "[ASSISTANT]"
144
+ content = msg.get('content', '')
145
+ formatted.append(f"{role_tag}\n{content}")
146
+ return "\n\n".join(formatted)
147
+
148
+
149
+ # --- filter_ultrachat_dataset_v2 函数保持不变 (除了调用更新后的 is_prompt_suitable) ---
150
+ def filter_ultrachat_dataset_v2(dataset_name="HuggingFaceH4/ultrachat_200k", split="train_sft"):
151
+ """
152
+ Load and filter the UltraChat dataset (field access corrected to match the actual record structure).
153
+ Filtering uses the updated is_prompt_suitable.
154
+ """
155
+ print(f"加载数据集: {dataset_name}, split: {split}...")
156
+ try:
157
+ dataset = load_dataset(dataset_name, split=split)
158
+ print(f"'{split}' split 加载完成。")
159
+ except Exception as e:
160
+ print(f"错误:无法加载数据集 {dataset_name} 的 '{split}' split。")
161
+ print(f"错误详情: {e}")
162
+ return []
163
+
164
+ filtered_samples = []
165
+ processed_dialogues = 0
166
+ extracted_samples = 0
167
+ skipped_garbled_dialogue = 0
168
+ skipped_short_dialogue = 0
169
+ skipped_bad_format = 0
170
+
171
+ print(f"\n开始处理 '{split}' split 中的对话...")
172
+ for dialogue in tqdm(dataset, desc="Processing dialogues"):
173
+ processed_dialogues += 1
174
+
175
+ messages = dialogue.get("messages")
176
+ prompt_id = dialogue.get("prompt_id")
177
+ initial_prompt = dialogue.get("prompt")
178
+
179
+ if not prompt_id: continue
180
+ if not messages or not isinstance(messages, list):
181
+ skipped_bad_format += 1
182
+ continue
183
+ if len(messages) < MIN_DIALOGUE_TURNS:
184
+ skipped_short_dialogue += 1
185
+ continue
186
+
187
+ dialogue_seems_garbled = False
188
+ for msg in messages:
189
+ content = msg.get("content")
190
+ # 对话级损坏检查现在仅基于 is_potentially_garbled
191
+ if is_potentially_garbled(content):
192
+ dialogue_seems_garbled = True
193
+ break
194
+ if dialogue_seems_garbled:
195
+ skipped_garbled_dialogue += 1
196
+ continue
197
+
198
+ current_history_list = []
199
+ for i, message in enumerate(messages):
200
+ role = message.get("role")
201
+ content = message.get("content", "").strip()
202
+
203
+ if not role or not content:
204
+ continue
205
+
206
+ if role == "user":
207
+ # Call the updated filtering function
208
+ if is_prompt_suitable(content, i):
209
+ history_text = format_history(current_history_list)
210
+ filtered_samples.append({
211
+ "dialogue_id": prompt_id,
212
+ "turn_index": i,
213
+ "query": content,
214
+ "history": history_text
215
+ })
216
+ extracted_samples += 1
217
+ # else: # uncomment the inner print to see rejection reasons
218
+ # pass
219
+
220
+ current_history_list.append({"role": role, "content": content})
221
+
222
+ print(f"\n过滤完成。")
223
+ print(f"处理对话数: {processed_dialogues}")
224
+ print(f"因格式错误跳过: {skipped_bad_format}")
225
+ print(f"因 messages 列表过短 (<{MIN_DIALOGUE_TURNS} turns) 跳过: {skipped_short_dialogue}")
226
+ print(f"因疑似损坏跳过的对话数 (基于 is_potentially_garbled): {skipped_garbled_dialogue}")
227
+ print(f"提取出的有效用户提问样本数: {extracted_samples}")
228
+ return filtered_samples
229
+
230
+ # --- Main program (unchanged, except it calls the V2 function) ---
231
+ if __name__ == "__main__":
232
+ # Call the corrected filtering function
233
+ filtered_data_list = filter_ultrachat_dataset_v2(dataset_name="HuggingFaceH4/ultrachat_200k", split="train_sft")
234
+
235
+ if filtered_data_list:
236
+ filtered_dataset = Dataset.from_list(filtered_data_list)
237
+ # Update the output directory name to reflect the new filtering rules
238
+ output_path = "./ultrachat_filtered_for_tts_preference_v3_nocode"
239
+ print(f"\nSaving the filtered dataset to: {output_path}")
240
+ os.makedirs(output_path, exist_ok=True)
241
+ filtered_dataset.save_to_disk(output_path)
242
+ print("数据集保存完成.")
243
+
244
+ print("\n部分样本预览 (从保存的 Dataset 加载):")
245
+ try:
246
+ loaded_dataset = Dataset.load_from_disk(output_path)
247
+ for i in range(min(5, len(loaded_dataset))):
248
+ sample = loaded_dataset[i]
249
+ print(f"--- 样本 {i+1} (Dialogue ID: {sample['dialogue_id']}, Turn: {sample['turn_index']}) ---")
250
+ print(f"History (last 500 chars):\n...{sample['history'][-500:]}")
251
+ print(f"\nQuery: {sample['query']}")
252
+ print("-" * 20)
253
+ except Exception as e:
254
+ print(f"加载预览样本时出错: {e}")
255
+
256
+ else:
257
+ print("\n没有找到符合条件的样本。可能原因:")
258
+ print("1. 过滤参数过于严格 (检查 MIN/MAX word counts, SIMPLE_PROMPT_PATTERNS, 新增的代码/TTS过滤参数等)。")
259
+ print("2. `is_potentially_garbled` 规则误判。")
260
+ print("3. 数据集本身在此 split 中没有符合条件的对话。")
261
+ print("4. (请检查脚本输出的跳过计数,看是哪个阶段跳过了大量样本)")
r1-a/dataset/ultrachat_tts.py ADDED
@@ -0,0 +1,382 @@
1
+ # --- SET CUDA DEVICE ---
2
+ # Method 1: Set environment variable BEFORE importing torch/cosyvoice
3
+ # This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
4
+ import os
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
6
+ # --- End CUDA Device Setting ---
7
+
8
+ import random
9
+ import torch
10
+ import torchaudio
11
+ # Make sure necessary types are imported
12
+ from datasets import load_dataset, Dataset, load_from_disk, Features, Value
13
+ import sys
14
+ from tqdm import tqdm
15
+ import time
16
+ import shutil # Added for potentially removing old dataset save dirs
17
+
18
+ # Check if the specified GPU is available after setting the environment variable
19
+ if not torch.cuda.is_available():
20
+ print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. Check your PyTorch installation, GPU drivers, and that GPU 1 exists and is functional.")
21
+ # Force exit if the intended GPU is not found
22
+ sys.exit(1)
23
+ else:
24
+ # Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0
25
+ effective_device = torch.device("cuda:0")
26
+ try:
27
+ print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1
28
+ print(f"Script will effectively run on: {effective_device}")
29
+ # Perform a small check to ensure the device is usable
30
+ _ = torch.tensor([1.0]).to(effective_device)
31
+ print("Device check successful.")
32
+ except Exception as e:
33
+ print(f"ERROR: Failed CUDA device check for visible device 'cuda:0' (original GPU 1): {e}")
34
+ sys.exit(1)
35
+
36
+
37
+ # Ensure CosyVoice path is correct
38
+ COSYVOICE_PATH = '/root/autodl-tmp/CosyVoice' # Make sure this path is correct
39
+ if not os.path.isdir(COSYVOICE_PATH):
40
+ print(f"ERROR: CosyVoice path not found: {COSYVOICE_PATH}")
41
+ sys.exit(1)
42
+ sys.path.append(COSYVOICE_PATH)
43
+
44
+ # Import CosyVoice *after* setting the environment variable
45
+ try:
46
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
47
+ from cosyvoice.utils.file_utils import load_wav
48
+ print("CosyVoice imported successfully.")
49
+ except ImportError as e:
50
+ print(f"Error importing CosyVoice: {e}")
51
+ print(f"Please ensure the path '{COSYVOICE_PATH}' is correct and the library is installed within that directory.")
52
+ sys.exit(1)
53
+ except Exception as e:
54
+ print(f"An unexpected error occurred during CosyVoice import: {e}")
55
+ sys.exit(1)
56
+
57
+ # ------------------------
58
+ # Configuration parameters (MODIFIED FOR Selected UltraChat DATASET)
59
+ # ------------------------
60
+ COMMON_VOICE_LANGUAGE = "en" # Language for prompts
61
+
62
+ # --- !! MODIFIED !! ---
63
+ # Input: Path to the SELECTED UltraChat dataset (Top 20%) from the previous script
64
+ INPUT_DATASET_PATH = "./ultrachat_final_top20_percent/ultrachat_top20_percent_by_complexity"
65
+ # Output: Directory to save new audio files and the final dataset object for THIS specific dataset
66
+ OUTPUT_DATASET_PATH = './ultrachat_top20_percent_with_query_audio' # New distinct output path
67
+ # --- End MODIFIED ---
68
+
69
+ SAMPLE_RATE = 16000 # Target sample rate for TTS output (should match CosyVoice default)
70
+ MAX_TTS_RETRIES = 3
71
+ RETRY_DELAY_SECONDS = 3
72
+
73
+ # ------------------------
74
+ # Helper functions (GPU handling and core TTS logic - UNCHANGED as requested)
75
+ # ------------------------
76
+ def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
77
+ """
78
+ Randomly draw one utterance and its transcript from the VoxPopuli dataset to serve as the prompt.
79
+ (Logic remains unchanged from previous TTS script)
80
+ """
81
+ idx = random.randint(0, len(common_voice_dataset) - 1)
82
+ try:
83
+ sample = common_voice_dataset.select([idx]).with_format('numpy')[0]
84
+ audio = sample['audio']
85
+ waveform = torch.tensor(audio['array'], dtype=torch.float32) # CPU
86
+ sr = audio['sampling_rate']
87
+ if sr != sample_rate:
88
+ if waveform.dim() > 1: waveform = waveform.mean(dim=0)
89
+ if waveform.dim() != 1: return get_random_prompt(common_voice_dataset, sample_rate)
90
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
91
+ waveform = resampler(waveform)
92
+ if waveform.dim() == 1: waveform = waveform.unsqueeze(0)
93
+ elif waveform.dim() > 2: return get_random_prompt(common_voice_dataset, sample_rate)
94
+ raw_text = sample.get('raw_text', '')
95
+ if waveform.numel() == 0 or not raw_text or not raw_text.strip():
96
+ return get_random_prompt(common_voice_dataset, sample_rate)
97
+ return waveform, raw_text # Return CPU tensor
98
+ except Exception as e:
99
+ time.sleep(0.1)
100
+ return get_random_prompt(common_voice_dataset, sample_rate)
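+ # Note (editor's sketch): the retries above recurse without a depth limit; a bounded
+ # alternative would loop instead, e.g.
+ #   for _ in range(20):            # 20 is an arbitrary cap, not from the original script
+ #       try:
+ #           ...                    # same sampling / resampling logic as above
+ #           return waveform, raw_text
+ #       except Exception:
+ #           time.sleep(0.1)
+ #   raise RuntimeError("No usable prompt sample found")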
101
+
102
+ def text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
103
+ """
104
+ Convert the input text to speech with the CosyVoice2 model, using a random prompt for zero-shot inference.
105
+ Includes retry logic on failure. Assumes cosyvoice runs on the configured device.
106
+ (Logic remains unchanged from previous TTS script)
107
+ """
108
+ last_exception = None
109
+ prompt_speech, prompt_text = None, "N/A"
110
+ for attempt in range(max_retries):
111
+ try:
112
+ prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) # CPU tensor
113
+ all_speech = []
114
+ inference_generator = cosyvoice.inference_zero_shot(
115
+ text_to_convert, prompt_text, prompt_speech, stream=stream, text_frontend=False
116
+ )
117
+ for i, chunk in enumerate(inference_generator): # Chunks on GPU
118
+ if chunk is None: continue
119
+ if 'tts_speech' in chunk and chunk['tts_speech'] is not None and chunk['tts_speech'].numel() > 0:
120
+ gpu_chunk = chunk['tts_speech'].to(effective_device)
121
+ all_speech.append(gpu_chunk)
122
+ if not all_speech:
123
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
124
+ raise ValueError("TTS inference finished but produced no valid audio chunks.")
125
+ combined_speech = torch.cat(all_speech, dim=-1) # GPU tensor
126
+ sample_rate_val = cosyvoice.sample_rate
127
+ if torch.max(torch.abs(combined_speech)) < 0.001:
128
+ raise ValueError("Generated audio is silent")
129
+ return {'audio_tensor': combined_speech, 'sample_rate': sample_rate_val} # Return GPU tensor
130
+ except Exception as e:
131
+ last_exception = e
132
+ print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
133
+ print(f" Text: '{text_to_convert[:100]}...'")
134
+ # print(f" Prompt Text Used: '{prompt_text[:100]}...'") # Reduce log noise
135
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
136
+ if attempt < max_retries - 1:
137
+ print(f" Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
138
+ time.sleep(RETRY_DELAY_SECONDS)
139
+ else:
140
+ print(f" All {max_retries} TTS attempts failed.")
141
+ return None
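+ # Usage (editor's sketch): result = text_to_audio("Hello there.", cosyvoice, common_voice)
+ # yields {'audio_tensor': <tensor on the GPU>, 'sample_rate': cosyvoice.sample_rate},
+ # or None once all retries are exhausted.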
142
+
143
+ # --- PROCESS EXAMPLE (Targets 'query' field) ---
144
+ def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
145
+ """
146
+ Run TTS on a single sample of the *Selected Top 20% UltraChat* dataset loaded from disk.
147
+ Processes the example['query'] field.
148
+ """
149
+ text_to_convert = example.get('query')
150
+ # Get identifiers for logging, if they exist in this dataset version
151
+ dialogue_id = example.get('dialogue_id', 'N/A')
152
+ turn_index = example.get('turn_index', 'N/A') # May not be present if not carried over
153
+
154
+ if not text_to_convert or not isinstance(text_to_convert, str) or not text_to_convert.strip():
155
+ print(f"Warning: Skipping example (ID: {dialogue_id}, Turn: {turn_index}) due to missing or empty 'query' field.")
156
+ return None
157
+
158
+ # Call the unchanged text_to_audio function
159
+ audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False)
160
+
161
+ if audio_result is not None:
162
+ audio_tensor = audio_result['audio_tensor'] # Still on GPU here
163
+ if audio_tensor is None or audio_tensor.numel() == 0:
164
+ print(f"Warning: TTS process returned empty tensor for query (ID: {dialogue_id}, Turn: {turn_index}): '{text_to_convert[:60]}...'")
165
+ return None
166
+ if audio_tensor.dim() == 1: audio_tensor = audio_tensor.unsqueeze(0)
167
+ elif audio_tensor.dim() > 2:
168
+ print(f"Warning: Generated audio tensor unexpected shape {audio_tensor.shape} (ID: {dialogue_id}, Turn: {turn_index}). Flattening.")
169
+ audio_tensor = audio_tensor.view(1, -1) # Flatten to [1, T]
170
+ if audio_tensor.numel() == 0:
171
+ print(f"Warning: Generated audio tensor became empty after reshape for query (ID: {dialogue_id}, Turn: {turn_index}): '{text_to_convert[:60]}...'")
172
+ return None
173
+ return {
174
+ 'audio_tensor': audio_tensor, # Return GPU tensor
175
+ 'sample_rate': audio_result['sample_rate']
176
+ }
177
+ else:
178
+ return None # Errors logged within text_to_audio
179
+
180
+ # ------------------------
181
+ # Data loading and model initialization (Model and Prompt Dataset Loading Unchanged)
182
+ # ------------------------
183
+ print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
184
+ try:
185
+ common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train', trust_remote_code=True)
186
+ common_voice = common_voice.filter(lambda x: x['audio'] is not None and x['audio']['array'] is not None and x['raw_text'] is not None and len(x['raw_text'].strip()) > 5 and len(x['audio']['array']) > SAMPLE_RATE * 0.5)
187
+ print(f"Loaded and filtered VoxPopuli '{COMMON_VOICE_LANGUAGE}' samples: {len(common_voice)}")
188
+ if len(common_voice) == 0: raise ValueError(f"VoxPopuli '{COMMON_VOICE_LANGUAGE}' loaded but no valid samples after filtering.")
189
+ except Exception as e:
190
+ print(f"Error loading or filtering VoxPopuli dataset: {e}")
191
+ sys.exit(1)
192
+
193
+ print("Initializing CosyVoice2 model...")
194
+ try:
195
+ # CosyVoice initialization remains the same
196
+ cosyvoice_model_path = os.path.join(COSYVOICE_PATH, 'pretrained_models/CosyVoice2-0.5B')
197
+ if not os.path.isdir(cosyvoice_model_path): raise FileNotFoundError(f"CosyVoice pretrained model directory not found: {cosyvoice_model_path}")
198
+ cosyvoice = CosyVoice2(
199
+ cosyvoice_model_path, load_jit=True, load_trt=False, fp16=False
200
+ )
201
+ print(f"CosyVoice model initialized. Target device: {effective_device}")
202
+ except Exception as e:
203
+ print(f"Error initializing CosyVoice2 model: {e}")
204
+ if isinstance(e, RuntimeError) and 'CUDA' in str(e): print("CUDA initialization error? Check GPU 1 status/memory.")
205
+ sys.exit(1)
206
+
207
+ # --- !! MODIFIED Selected UltraChat Dataset Loading !! ---
208
+ print(f"\nLoading the target Selected UltraChat (Top 20%) dataset from disk: {INPUT_DATASET_PATH}")
209
+ if not os.path.exists(INPUT_DATASET_PATH):
210
+ print(f"Error: Input dataset directory not found at '{INPUT_DATASET_PATH}'.")
211
+ print("Please ensure the UltraChat Selection script ran successfully and produced the dataset at this location.")
212
+ sys.exit(1)
213
+
214
+ try:
215
+ input_dataset = load_from_disk(INPUT_DATASET_PATH)
216
+
217
+ print(f"Successfully loaded Selected UltraChat dataset with {len(input_dataset)} examples.")
218
+ if len(input_dataset) == 0:
219
+ print("Error: The loaded dataset is empty. Cannot proceed.")
220
+ sys.exit(1)
221
+ # Store original features to reconstruct the final dataset correctly
222
+ original_features = input_dataset.features
223
+ print(f"Original features: {original_features}")
224
+ # Check for 'query' column existence (essential for TTS)
225
+ if 'query' not in original_features:
226
+ print(f"Error: The loaded dataset from '{INPUT_DATASET_PATH}' does not contain the required 'query' column.")
227
+ sys.exit(1)
228
+
229
+ except Exception as e:
230
+ print(f"Error loading dataset from '{INPUT_DATASET_PATH}': {e}")
231
+ sys.exit(1)
232
+ # --- End MODIFIED Dataset Loading ---
233
+
234
+
235
+ # --- Create output directories ---
236
+ os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)
237
+ audio_output_dir = os.path.join(OUTPUT_DATASET_PATH, "audio_files")
238
+ os.makedirs(audio_output_dir, exist_ok=True)
239
+ print(f"Audio files will be saved in: {audio_output_dir}")
240
+ progress_file = os.path.join(OUTPUT_DATASET_PATH, "progress.txt")
241
+ print(f"Progress will be tracked in: {progress_file}")
242
+
243
+
244
+ # ------------------------
245
+ # Main processing loop (MODIFIED FOR SINGLE Selected UltraChat DATASET)
246
+ # ------------------------
247
+ # --- !! MODIFIED: Update log message !! ---
248
+ print(f"\nStarting TTS processing for {len(input_dataset)} Selected UltraChat (Top 20%) samples...")
249
+
250
+ start_index = 0
251
+ # Read progress file to resume if necessary
252
+ if os.path.exists(progress_file):
253
+ try:
254
+ with open(progress_file, "r") as f:
255
+ content = f.read().strip()
256
+ if content: start_index = int(content)
257
+ print(f"Resuming TTS processing from sample index {start_index}")
258
+ except Exception as e:
259
+ print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.")
260
+ start_index = 0
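+ # Note (editor's comment): progress.txt holds the index of the next sample to process
+ # (the loop writes i + 1 after finishing sample i); deleting the file restarts from index 0.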
261
+
262
+ # List to hold dictionaries for the final dataset
263
+ final_samples = []
264
+
265
+ # --- Main Loop ---
266
+ # --- !! MODIFIED: Update progress bar description !! ---
267
+ pbar = tqdm(range(start_index, len(input_dataset)), desc=f"TTS on Selected UltraChat 'query'", initial=start_index, total=len(input_dataset))
268
+ for i in pbar:
269
+ sample = input_dataset[i] # Get sample dictionary (on CPU)
270
+
271
+ # Define unique output WAV path using the index
272
+ output_wav_filename = f"query_{i}.wav"
273
+ output_wav_path = os.path.join(audio_output_dir, output_wav_filename)
274
+
275
+ # --- Check if audio file already exists ---
276
+ if os.path.exists(output_wav_path):
277
+ sample_dict = dict(sample)
278
+ sample_dict["query_audio_filepath"] = output_wav_path # Add path field
279
+ final_samples.append(sample_dict)
280
+ with open(progress_file, "w") as f: f.write(str(i + 1))
281
+ continue # Skip TTS
282
+
283
+ # --- Perform TTS on the target device ---
284
+ result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
285
+
286
+ if result is not None and 'audio_tensor' in result and result['audio_tensor'] is not None:
287
+ audio_tensor = result['audio_tensor'] # GPU tensor
288
+ sample_rate_val = result['sample_rate']
289
+ try:
290
+ # Move tensor to CPU before saving
291
+ audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
292
+ if audio_tensor_save.dim() == 1: audio_tensor_save = audio_tensor_save.unsqueeze(0)
293
+ elif audio_tensor_save.dim() > 2: audio_tensor_save = audio_tensor_save.view(1, -1)
294
+
295
+ torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)
296
+
297
+ # Create dict for the final dataset
298
+ sample_dict = dict(sample)
299
+ sample_dict["query_audio_filepath"] = output_wav_path # Add path field
300
+ final_samples.append(sample_dict)
301
+
302
+ del audio_tensor # Delete GPU tensor
303
+
304
+ except Exception as e:
305
+ # Log error with identifiers if available
306
+ dialogue_id = sample.get('dialogue_id', 'N/A')
307
+ turn_index = sample.get('turn_index', 'N/A')
308
+ print(f"Failed to save wav for sample {i} (ID: {dialogue_id}, Turn: {turn_index}) at {output_wav_path}: {e}")
309
+ if os.path.exists(output_wav_path):
310
+ try: os.remove(output_wav_path)
311
+ except OSError: pass
312
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
313
+ else:
314
+ # Failure logged in process_example/text_to_audio
315
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
316
+
317
+ # --- Update progress file ---
318
+ with open(progress_file, "w") as f: f.write(str(i + 1))
319
+
320
+ # --- Optional: Periodic cache clearing ---
321
+ if i > 0 and i % 50 == 0:
322
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
323
+
324
+
325
+ # --- Final cache clear after finishing the loop ---
326
+ if torch.cuda.is_available():
327
+ print("Clearing final CUDA cache...")
328
+ torch.cuda.empty_cache()
329
+
330
+ # ------------------------
331
+ # Save the final dataset (MODIFIED FOR Selected UltraChat)
332
+ # ------------------------
333
+ print("\nTTS processing loop finished.")
334
+ if final_samples:
335
+ # --- !! MODIFIED: Update log message !! ---
336
+ print(f"Successfully processed (or skipped existing) {len(final_samples)} Selected UltraChat (Top 20%) samples.")
337
+
338
+ # --- Define features for the new dataset ---
339
+ new_features_dict = original_features.copy()
340
+ new_column_name = 'query_audio_filepath' # Name of the new column
341
+ if new_column_name in new_features_dict:
342
+ print(f"Warning: Feature '{new_column_name}' already exists in original features. Overwriting.")
343
+ new_features_dict[new_column_name] = Value('string') # Add the new column definition
344
+ try:
345
+ new_features = Features(new_features_dict)
346
+ print(f"Defined new features for saving: {new_features}")
347
+
348
+ # --- Create the final Dataset object ---
349
+ print("Creating final Dataset object from processed samples...")
350
+ final_dataset_obj = Dataset.from_list(final_samples, features=new_features)
351
+
352
+ # --- Define path to save the final dataset metadata object ---
353
+ final_dataset_save_path = os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio")
354
+ # --- !! MODIFIED: Update log message !! ---
355
+ print(f"Saving final Selected UltraChat (Top 20%) dataset metadata (with audio paths) to: {final_dataset_save_path}...")
356
+
357
+ # Ensure the target directory exists and is empty before saving
358
+ if os.path.exists(final_dataset_save_path):
359
+ print(f"Removing existing directory before saving: {final_dataset_save_path}")
360
+ shutil.rmtree(final_dataset_save_path)
361
+
362
+ final_dataset_obj.save_to_disk(final_dataset_save_path)
363
+ print(f"Final dataset object saved successfully.")
364
+
365
+ except Exception as e:
366
+ print(f"\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
367
+ print(f"Error during final dataset creation or saving: {e}")
368
+ print(f"Audio files might be saved in '{audio_output_dir}', but the final dataset object could not be created/saved.")
369
+ print(f"Check the features and the content of 'final_samples'.")
370
+ print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
371
+
372
+ else:
373
+ print("Processing finished, but no samples were successfully processed or had existing audio files.")
374
+ print(f"Check logs for TTS errors. Audio files directory: '{audio_output_dir}'.")
375
+
376
+
377
+ print("\n" + "="*60)
378
+ # --- !! MODIFIED: Update final log messages !! ---
379
+ print(f"Script finished for Selected UltraChat (Top 20%) dataset.")
380
+ print(f"Generated audio files are located in: '{audio_output_dir}'")
381
+ print(f"The final dataset (metadata + audio file paths) is saved at: '{os.path.join(OUTPUT_DATASET_PATH, 'processed_dataset_with_audio')}' (if saving was successful)")
382
+ print("="*60)
r1-a/prompt_only_examine.py ADDED
@@ -0,0 +1,48 @@
1
+ from datasets import load_from_disk
2
+ import torchaudio
3
+ import os
4
+
5
+ # IMPORTANT: When you load and use this dataset, your CWD should have the same
6
+ # relationship to the audio files as it did when the dataset was created,
7
+ # OR you need to manually resolve the paths.
8
+
9
+ # Path where the HF dataset was saved
10
+ dataset_path = '/root/autodl-tmp/audio-r1/r1-a/dataset/prompt_only_fully_merged_with_audio/final_hf_dataset_relative_paths_v4'
11
+ ds = load_from_disk(dataset_path)
12
+ breakpoint() # Use this to inspect the dataset structure if needed
13
+ # Get a relative path string from the dataset
14
+ relative_path_from_dataset = ds[0]['question_audio_relative_path'] # Or your field name
15
+ print(f"Relative path from dataset: {relative_path_from_dataset}")
16
+
17
+ # To load the audio, this relative path needs to resolve correctly from your *current* CWD
18
+ # Option 1: If your CWD is correct
19
+ # current_cwd_when_loading = os.getcwd()
20
+ # print(f"Current CWD when loading: {current_cwd_when_loading}")
21
+ # full_path_to_audio = os.path.abspath(relative_path_from_dataset) # os.path.abspath resolves based on CWD
22
+
23
+ # Option 2: If you know the dataset's "root" directory relative to which paths were made
24
+ # This is safer if you move the dataset and audio files together.
25
+ # Assume the dataset was created when CWD was '/root/autodl-tmp/audio-r1/r1-a/dataset/'
26
+ # And now you are running this loading script from somewhere else, but you know that 'root'.
27
+ # For example, if your audio files are now located such that the relative path still makes sense
28
+ # if prepended by a new base_dir.
29
+
30
+ # Example: If you know the original CWD when the dataset was created,
31
+ # and you want to reconstruct the absolute path assuming the audio files haven't moved
32
+ # This is generally what os.path.join(original_cwd, relative_path) would give if files are static
33
+ # However, if you've moved the dataset AND audio files together, keeping their relative structure,
34
+ # then the relative_path_from_dataset should resolve correctly if your CWD is the new "root" of that structure.
35
+
36
+ # The simplest way if you are in the correct CWD when loading:
37
+ full_path_to_audio = os.path.join(os.getcwd(), relative_path_from_dataset) # This might not be right if rel path has '..'
38
+ full_path_to_audio = os.path.abspath(relative_path_from_dataset) # This is usually what you want if CWD is the intended base
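+ # Option 3 (editor's sketch): resolve against an explicitly known base directory rather than
+ # the CWD; 'known_base_dir' is a hypothetical variable, not defined elsewhere in this script.
+ # known_base_dir = '/root/autodl-tmp/audio-r1/r1-a/dataset'
+ # full_path_to_audio = os.path.normpath(os.path.join(known_base_dir, relative_path_from_dataset))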
39
+
40
+ print(f"Attempting to load from (resolved path): {full_path_to_audio}")
41
+
42
+ if os.path.exists(full_path_to_audio):
43
+ waveform, sample_rate = torchaudio.load(full_path_to_audio)
44
+ print(f"Loaded audio: waveform shape {waveform.shape}, sample rate {sample_rate}")
45
+ else:
46
+ print(f"Audio file NOT FOUND at resolved path: {full_path_to_audio}")
47
+ print("Ensure your Current Working Directory is set correctly so the relative path can be resolved.")
48
+ print(f"Alternatively, manually construct the absolute path if you know where the audio files are relative to a fixed base.")