diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..53aecba8e5155fc7916705d1fbe6268dabb3cee5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,30 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/450.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/45.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/358.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/369.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/316.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/454.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/376.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/395.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/359.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/447.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/299.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/302.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/394.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/350.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/385.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/401.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/463.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/458.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/457.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/515.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/301.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/372.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/314.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/397.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/465.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/459.wav filter=lfs diff=lfs merge=lfs -text +r1-a/dataset/gsm8k_with_audio/test/400.wav filter=lfs diff=lfs merge=lfs -text diff --git a/r1-a/dataset/ai2_arc.py b/r1-a/dataset/ai2_arc.py new file mode 100644 index 0000000000000000000000000000000000000000..ccedcdfcfb0dad277a0472d8e1e138376c81d403 --- /dev/null +++ b/r1-a/dataset/ai2_arc.py @@ -0,0 +1,175 @@ +import os +import random +import torch +import torchaudio +from datasets import load_dataset, Dataset +import sys +from tqdm import tqdm + +sys.path.append('/root/autodl-tmp/CosyVoice') +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + +# ------------------------ +# 配置参数 +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" +DATASET_NAME = "ai2_arc" +OUTPUT_DATASET_PATH = './arc_easy_with_audio' # 输出目录 +SAMPLE_RATE = 16000 + +# ------------------------ +# 辅助函数 +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。 + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + sample = 
common_voice_dataset.select([idx])[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) + sr = audio['sampling_rate'] + if sr != sample_rate: + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + return waveform.unsqueeze(0), sample['raw_text'] + +def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + """ + try: + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) + + all_speech = [] + for i, j in enumerate(cosyvoice.inference_zero_shot( + query_text, + prompt_text, + prompt_speech, + stream=stream, + text_frontend=False + )): + all_speech.append(j['tts_speech']) + + # 将所有生成的语音片段拼接在一起 + combined_speech = torch.cat(all_speech, dim=-1) + sample_rate_val = cosyvoice.sample_rate + + return { + 'audio_tensor': combined_speech, + 'sample_rate': sample_rate_val + } + except Exception as e: + print(f"Error converting text to audio: {e}") + return None + +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对 AI2 ARC 数据集中的单个样本进行 TTS 处理。 + 在此示例中,仅对 sample['question'] 字段执行 TTS。 + """ + query = example['question'] + audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False) + if audio_result is not None: + return { + 'audio_tensor': audio_result['audio_tensor'], + 'sample_rate': audio_result['sample_rate'] + } + else: + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset...") +common_voice = load_dataset("facebook/voxpopuli", "en", split='train') +print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + +print("Initializing CosyVoice2 model...") +cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # 替换为实际模型路径 + load_jit=True, + load_trt=False, + fp16=False +) + +print("Loading ARC-Challenge dataset...") +# 如果想处理 ARC-Easy,只需改为 "ARC-Easy" +dataset = load_dataset("allenai/ai2_arc", "ARC-Easy") + +# 创建输出目录 +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) + +# ------------------------ +# 主处理循环 +# ------------------------ +final_dataset_dict = {} # 存放各 split 最终处理后的数据 + +for split_name, split_dataset in dataset.items(): + print(f"Processing split: {split_name} with {len(split_dataset)} examples") + split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + + # 用于断点续跑的进度记录 + progress_file = os.path.join(split_output_dir, "progress.txt") + start_index = 0 + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + start_index = int(f.read().strip()) + print(f"Resuming split '{split_name}' from sample index {start_index}") + except Exception as e: + print(f"读取进度文件失败:{e}") + + final_samples = [] + + # 遍历处理每条样本 + for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"): + # 如果已处理过,就直接跳过并仅把已存在文件的元信息记入 final_samples + if i < start_index: + sample = split_dataset[i] + wav_path = os.path.join(split_output_dir, f"{i}.wav") + if os.path.exists(wav_path): + # 保留所有原始字段 + 音频路径 + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = wav_path + final_samples.append(sample_dict) + continue + + sample = split_dataset[i] + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + audio_tensor = result['audio_tensor'] + 
if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + sample_rate_val = result['sample_rate'] + + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + try: + torchaudio.save(output_wav_path, audio_tensor, sample_rate_val) + except Exception as e: + print(f"Failed to save wav for sample {i}: {e}") + continue + + # 保留所有原始字段 + 生成的音频路径 + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + else: + print(f"Sample {i} processing failed, no audio generated.") + + # 更新进度记录 + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # 生成 Hugging Face Dataset 并落盘 + final_dataset_obj = Dataset.from_list(final_samples) + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") + final_dataset_obj.save_to_disk(final_dataset_save_path) + + print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.") + final_dataset_dict[split_name] = final_dataset_obj + +print("所有分割处理完毕,最终数据集已保存。") diff --git a/r1-a/dataset/alpaca.py b/r1-a/dataset/alpaca.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe988bb26b15b64c7765473e78fbc2806fe69f8 --- /dev/null +++ b/r1-a/dataset/alpaca.py @@ -0,0 +1,346 @@ +# --- SET CUDA DEVICE --- +# Method 1: Set environment variable BEFORE importing torch/cosyvoice +# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally. +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "1" +# --- End CUDA Device Setting --- + +import random +import torch +import torchaudio +from datasets import load_dataset, Dataset, load_from_disk +import sys +from tqdm import tqdm +import time + +# Check if the specified GPU is available after setting the environment variable +if not torch.cuda.is_available(): + print("WARNING: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. 
Check your PyTorch installation and GPU drivers.") + print("Attempting to run on CPU, but this will be very slow.") + # Decide if you want to exit or proceed on CPU + # sys.exit(1) # Uncomment to exit if GPU not found + effective_device = torch.device("cpu") +else: + # Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0 + effective_device = torch.device("cuda:0") + print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1 + print(f"Script will effectively run on: {effective_device}") + + +sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct +# Import CosyVoice *after* setting the environment variable +try: + from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 + from cosyvoice.utils.file_utils import load_wav +except ImportError as e: + print(f"Error importing CosyVoice: {e}") + print("Please ensure the path '/root/autodl-tmp/CosyVoice' is correct and the library is installed.") + sys.exit(1) + +# ------------------------ +# 配置参数 +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" +FILTERED_ALPACA_PATH = './alpaca_filtered_for_spoken_dialogue_v2' +SPLITS_TO_PROCESS = ['train'] +OUTPUT_DATASET_PATH = './alpaca_filtered_spoken_with_output_audio' # Keep output path distinct +SAMPLE_RATE = 16000 +MAX_TTS_RETRIES = 3 +RETRY_DELAY_SECONDS = 2 + +# ------------------------ +# 辅助函数 (No changes needed here, should run on the visible device) +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。 + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + sample = common_voice_dataset.select([idx])[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU + sr = audio['sampling_rate'] + if sr != sample_rate: + if waveform.dim() > 1: + waveform = waveform.mean(dim=0) + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + if waveform.numel() == 0 or not sample['raw_text']: + print("Warning: Got an empty prompt, trying again...") + return get_random_prompt(common_voice_dataset, sample_rate) + # Return CPU tensor, CosyVoice inference should handle moving it + return waveform, sample['raw_text'] + +def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + Includes retry logic on failure. Assumes cosyvoice runs on the configured device. 
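+    Returns a dict with 'audio_tensor' (a tensor on the TTS device) and 'sample_rate'
+    on success, or None once all retries are exhausted.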
+ """ + last_exception = None + for attempt in range(max_retries): + try: + # prompt_speech is initially on CPU + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) + + all_speech = [] + # cosyvoice.inference_zero_shot should internally use the GPU device it was initialized on + # (which should be the visible cuda:0, i.e., original cuda:1) + inference_generator = cosyvoice.inference_zero_shot( + query_text, + prompt_text, + prompt_speech, # Pass CPU tensor + stream=stream, + text_frontend=False + ) + # Generated chunks 'tts_speech' will be on the GPU + for i, chunk in enumerate(inference_generator): + if 'tts_speech' in chunk and chunk['tts_speech'] is not None: + all_speech.append(chunk['tts_speech']) + else: + print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'") + + if not all_speech: + # Clear GPU memory cache if an error occurs during generation + if torch.cuda.is_available(): torch.cuda.empty_cache() + raise ValueError("TTS inference finished but produced no audio chunks.") + + # combined_speech is on GPU + combined_speech = torch.cat(all_speech, dim=-1) + sample_rate_val = cosyvoice.sample_rate + + return { + # Return GPU tensor, will be moved to CPU before saving + 'audio_tensor': combined_speech, + 'sample_rate': sample_rate_val + } + except Exception as e: + last_exception = e + print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}") + print(f"Text: '{query_text[:100]}...'") + print(f"Prompt Text: '{prompt_text[:100]}...'") + # Clear GPU cache on error as well + if torch.cuda.is_available(): torch.cuda.empty_cache() + if attempt < max_retries - 1: + print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...") + time.sleep(RETRY_DELAY_SECONDS) + else: + print(f"All {max_retries} TTS attempts failed.") + + print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'") + print(f"Last error: {last_exception}") + return None + +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对从磁盘加载的过滤后 Alpaca 数据集中的单个样本进行 TTS 处理。 + Processes example['output']. + """ + text_to_convert = example.get('instruction')+example.get('input') + if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "": + print(f"Warning: Skipping example due to missing or empty 'output' field: {example.keys()}") + return None + + audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False) + + if audio_result is not None: + audio_tensor = audio_result['audio_tensor'] # Still on GPU here + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + elif audio_tensor.dim() > 2: + print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.") + audio_tensor = audio_tensor.view(1, -1) + + if audio_tensor.numel() == 0: + print(f"Warning: Generated audio tensor is empty for output text: '{text_to_convert[:60]}...'") + # Clear GPU cache even for empty tensor? Maybe not needed. 
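Editor's note: a few lines up, process_example builds its TTS input as example.get('instruction')+example.get('input'), which raises TypeError when either field is None and runs the two strings together when both are present; the docstring's claim that it processes example['output'] also looks stale. A minimal sketch of a safer join (the helper name build_tts_text is hypothetical, not part of the patch):

def build_tts_text(example):
    # Join the Alpaca 'instruction' and 'input' fields with a space, tolerating
    # missing or empty values; None signals "nothing to synthesize" to the caller.
    parts = (example.get("instruction"), example.get("input"))
    text = " ".join(p.strip() for p in parts if isinstance(p, str) and p.strip())
    return text or None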
+ return None + + return { + 'audio_tensor': audio_tensor, # Return GPU tensor + 'sample_rate': audio_result['sample_rate'] + } + else: + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset for prompts...") +try: + # Load prompt dataset to CPU memory + common_voice = load_dataset("facebook/voxpopuli", "en", split='train') + print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + if len(common_voice) == 0: + raise ValueError("VoxPopuli dataset loaded but contains no samples.") +except Exception as e: + print(f"Error loading VoxPopuli dataset: {e}") + sys.exit(1) + + +print("Initializing CosyVoice2 model...") +try: + # CosyVoice should automatically initialize on the visible device ('cuda:0', which is original 'cuda:1') + # No explicit device='cuda:1' needed here due to CUDA_VISIBLE_DEVICES + cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', + load_jit=True, + load_trt=False, # Ensure TRT is False if not set up for GPU 1 + fp16=False # Check if GPU 1 supports FP16 well if you enable this + # device=effective_device # Usually not needed if CUDA_VISIBLE_DEVICES is set, but uncomment if CosyVoice requires it explicitly + ) + print(f"CosyVoice model initialized. It should be using device: {effective_device}") +except Exception as e: + print(f"Error initializing CosyVoice2 model: {e}") + # Try to get more info if it's a CUDA error + if isinstance(e, RuntimeError) and 'CUDA' in str(e): + print("This might be a CUDA initialization error. Ensure GPU 1 is functional and has enough memory.") + sys.exit(1) + +print(f"Loading pre-filtered Alpaca dataset(s) from disk: {FILTERED_ALPACA_PATH}") +dataset_dict = {} +loaded_splits_count = 0 +for split_name in SPLITS_TO_PROCESS: + split_dir_name = f"{split_name}_dataset" + split_path = os.path.join(FILTERED_ALPACA_PATH, split_dir_name) + print(f"Attempting to load split '{split_name}' from: {split_path}") + try: + # Load dataset to CPU memory + split_dataset = load_from_disk(split_path) + + if not split_dataset: + print(f"Warning: Dataset loaded from '{split_path}' is empty or invalid. Skipping this split.") + continue + dataset_dict[split_name] = split_dataset + print(f"Successfully loaded split '{split_name}' with {len(split_dataset)} examples.") + loaded_splits_count += 1 + except FileNotFoundError: + print(f"Info: Filtered dataset split not found at '{split_path}'. Skipping this split.") + except Exception as e: + print(f"Error loading pre-filtered dataset split from '{split_path}': {e}. Skipping this split.") + +if loaded_splits_count == 0: + print(f"Error: Could not load any dataset splits from '{FILTERED_ALPACA_PATH}' using splits '{SPLITS_TO_PROCESS}'.") + sys.exit(1) + +# 创建输出目录 +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) + +# ------------------------ +# 主处理循环 +# ------------------------ +final_dataset_dict = {} + +for split_name, split_dataset in dataset_dict.items(): + print(f"\nProcessing loaded split: {split_name} with {len(split_dataset)} examples") + split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + + progress_file = os.path.join(split_output_dir, "progress.txt") + start_index = 0 + # ... 
(progress file reading logic remains the same) + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + content = f.read().strip() + if content: + start_index = int(content) + print(f"Resuming split '{split_name}' TTS from sample index {start_index}") + else: + print(f"Progress file '{progress_file}' is empty, starting TTS from index 0.") + start_index = 0 + except ValueError: + print(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.") + start_index = 0 + except Exception as e: + print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.") + start_index = 0 + + + final_samples = [] + + pbar = tqdm(range(start_index, len(split_dataset)), desc=f"TTS on '{split_name}' output field", initial=start_index, total=len(split_dataset)) + for i in pbar: + sample = split_dataset[i] # Sample data is on CPU + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + + if os.path.exists(output_wav_path): + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + with open(progress_file, "w") as f: + f.write(str(i + 1)) + continue + + # --- Perform TTS on the target device --- + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + audio_tensor = result['audio_tensor'] # Received tensor is on GPU + sample_rate_val = result['sample_rate'] + + try: + # --- Move tensor to CPU before saving --- + audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32) + if audio_tensor_save.dim() == 1: + audio_tensor_save = audio_tensor_save.unsqueeze(0) + elif audio_tensor_save.dim() > 2: + audio_tensor_save = audio_tensor_save.view(1, -1) + + torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val) + + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + + # --- Explicitly delete GPU tensor and clear cache periodically? --- + # Can sometimes help prevent memory creep in long loops + del audio_tensor + # if i % 50 == 0: # Example: clear cache every 50 iterations + # if torch.cuda.is_available(): torch.cuda.empty_cache() + + except Exception as e: + print(f"Failed to save wav for sample {i} ('output' field TTS) at {output_wav_path}: {e}") + # Clear cache on save error too, just in case + if torch.cuda.is_available(): torch.cuda.empty_cache() + else: + print(f"Sample {i} TTS failed after retries (Output Text: '{sample.get('output', 'N/A')[:60]}...'), no audio generated.") + # No tensor to delete if result is None + + # Update progress file + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # --- Optional: Add more frequent cache clearing --- + # if i % 20 == 0 and torch.cuda.is_available(): # Clear more often if memory is tight + # torch.cuda.empty_cache() + + + # --- Final cache clear after finishing a split --- + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # ... (Saving final dataset logic remains the same) + if final_samples: + final_dataset_obj = Dataset.from_list(final_samples) + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") + try: + print(f"Saving final dataset for split '{split_name}' (with new audio paths) to {final_dataset_save_path}...") + os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True) + final_dataset_obj.save_to_disk(final_dataset_save_path) + print(f"Finished processing split: {split_name}. 
Saved {len(final_samples)} samples with new audio paths for 'output' field.") + final_dataset_dict[split_name] = final_dataset_obj + except Exception as e: + print(f"Error saving final dataset for split '{split_name}' to disk: {e}") + else: + print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.") + + +print("="*30) +if final_dataset_dict: + print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.") + print(f"Processed splits: {list(final_dataset_dict.keys())}") +else: + print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.") +print("="*30) \ No newline at end of file diff --git a/r1-a/dataset/commonsense.py b/r1-a/dataset/commonsense.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad8b7f1b6b758d8c1c068a86a4cf303bad2e5af --- /dev/null +++ b/r1-a/dataset/commonsense.py @@ -0,0 +1,175 @@ +import os +import random +import torch +import torchaudio +from datasets import load_dataset, Dataset +import sys +from tqdm import tqdm + +sys.path.append('/root/autodl-tmp/CosyVoice') +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + +# ------------------------ +# 配置参数 +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" +DATASET_NAME = "commonsense_qa" +OUTPUT_DATASET_PATH = './commonsense_qa_with_audio' # 输出目录 +SAMPLE_RATE = 16000 + +# ------------------------ +# 辅助函数 +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli (此处替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。 + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + sample = common_voice_dataset.select([idx])[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) + sr = audio['sampling_rate'] + if sr != sample_rate: + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + return waveform.unsqueeze(0), sample['raw_text'] + +def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + """ + try: + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) + + all_speech = [] + for i, j in enumerate(cosyvoice.inference_zero_shot( + query_text, + prompt_text, + prompt_speech, + stream=stream, + text_frontend=False + )): + all_speech.append(j['tts_speech']) + + # 将所有生成的语音片段拼接在一起 + combined_speech = torch.cat(all_speech, dim=-1) + sample_rate_val = cosyvoice.sample_rate + + return { + 'audio_tensor': combined_speech, + 'sample_rate': sample_rate_val + } + except Exception as e: + print(f"Error converting text to audio: {e}") + return None + +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对 Commonsense QA 数据集中的单个样本进行 TTS 处理。 + 在此示例中,仅对 sample['question'] 字段执行 TTS。 + """ + query = example['question'] + audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False) + if audio_result is not None: + return { + 'audio_tensor': audio_result['audio_tensor'], + 'sample_rate': audio_result['sample_rate'] + } + else: + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset...") +common_voice = load_dataset("facebook/voxpopuli", "en", split='train') +print(f"Total VoxPopuli 
{COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + +print("Initializing CosyVoice2 model...") +cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # 替换为实际模型路径 + load_jit=True, + load_trt=False, + fp16=False +) + +print("Loading Commonsense QA dataset...") +dataset = load_dataset("tau/commonsense_qa") +# 如果只想处理 train,可写成 dataset = load_dataset("tau/commonsense_qa", split="train") + +# 创建输出目录 +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) + +# ------------------------ +# 主处理循环 +# ------------------------ +final_dataset_dict = {} # 存放各 split 最终处理后的数据 + +for split_name, split_dataset in dataset.items(): + print(f"Processing split: {split_name} with {len(split_dataset)} examples") + split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + + # 用于断点续跑的进度记录 + progress_file = os.path.join(split_output_dir, "progress.txt") + start_index = 0 + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + start_index = int(f.read().strip()) + print(f"Resuming split '{split_name}' from sample index {start_index}") + except Exception as e: + print(f"读取进度文件失败:{e}") + + final_samples = [] + + # 遍历处理每条样本 + for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"): + # 如果已处理过,就直接跳过并仅把已存在文件的元信息记入 final_samples + if i < start_index: + sample = split_dataset[i] + wav_path = os.path.join(split_output_dir, f"{i}.wav") + if os.path.exists(wav_path): + # 保留所有原始字段 + 音频路径 + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = wav_path + final_samples.append(sample_dict) + continue + + sample = split_dataset[i] + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + audio_tensor = result['audio_tensor'] + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + sample_rate_val = result['sample_rate'] + + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + try: + torchaudio.save(output_wav_path, audio_tensor, sample_rate_val) + except Exception as e: + print(f"Failed to save wav for sample {i}: {e}") + continue + + # 保留所有原始字段 + 生成的音频路径 + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + else: + print(f"Sample {i} processing failed, no audio generated.") + + # 更新进度记录 + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # 生成 Hugging Face Dataset 并落盘 + final_dataset_obj = Dataset.from_list(final_samples) + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") + final_dataset_obj.save_to_disk(final_dataset_save_path) + + print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.") + final_dataset_dict[split_name] = final_dataset_obj + +print("所有分割处理完毕,最终数据集已保存。") diff --git a/r1-a/dataset/examqa.py b/r1-a/dataset/examqa.py new file mode 100644 index 0000000000000000000000000000000000000000..9f4d326f441ea67f4d8a27c95125ef54527a6284 --- /dev/null +++ b/r1-a/dataset/examqa.py @@ -0,0 +1,440 @@ +# --- SET CUDA DEVICE --- +# Method 1: Set environment variable BEFORE importing torch/cosyvoice +# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally. 
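Editor's sketch (illustration only, not part of the patch) of what the masking below implies once torch initializes: CUDA_VISIBLE_DEVICES must be set before the process first touches CUDA, and PyTorch then renumbers the surviving GPUs from zero, so physical GPU 1 is addressed as "cuda:0" for the rest of the process.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"    # must be set before the first CUDA use

import torch                                # imported after the variable is set

if torch.cuda.is_available():
    assert torch.cuda.device_count() == 1   # only the masked-in GPU remains visible
    print(torch.cuda.get_device_name(0))    # reports the name of physical GPU 1
    t = torch.zeros(8, device="cuda:0")     # this allocation lands on physical GPU 1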
+import os +os.environ["CUDA_VISIBLE_DEVICES"] = "1" # <-- Keep your original setting +# --- End CUDA Device Setting --- + +import random +import torch +import torchaudio +# Import load_from_disk to load the dataset saved by your LLM script +from datasets import load_dataset, Dataset, load_from_disk, Features, Value, Sequence, ClassLabel # Added Features etc. for robustness +import sys +from tqdm import tqdm +import time +import logging # Add logging +import json # For fallback saving + +# Check if the specified GPU is available after setting the environment variable +if not torch.cuda.is_available(): + print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES. Cannot run TTS on GPU.") + print("Check your PyTorch installation, GPU drivers, and CUDA setup.") + sys.exit(1) # Exit if GPU is required and not found +else: + # Since CUDA_VISIBLE_DEVICES is set, the first *visible* device is cuda:0 + effective_device = torch.device("cuda:0") + print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") + print(f"Script will effectively run TTS inference on: {effective_device}") + + +sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct +# Import CosyVoice *after* setting the environment variable +try: + from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 + from cosyvoice.utils.file_utils import load_wav +except ImportError as e: + print(f"Error importing CosyVoice: {e}") + print("Please ensure the path '/root/autodl-tmp/CosyVoice' is correct and the library is installed.") + sys.exit(1) + +# Setup basic logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# ------------------------ +# 配置参数 +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" +# --- Path to the dataset output by the LLM rephrasing script --- +REPHRASED_DATASET_PATH = './Multi-subject-RLVR_rephrased/train_processed_final' # <-- ADJUST IF YOUR PATH IS DIFFERENT +# --- Output path for THIS TTS script --- +TTS_OUTPUT_PATH = './Multi-subject-RLVR_rephrased_with_audio' # <-- New path for results +SAMPLE_RATE = 16000 +MAX_TTS_RETRIES = 3 +RETRY_DELAY_SECONDS = 2 +# Define the assumed split name for directory structure (even if only one split) +ASSUMED_INPUT_SPLIT = "train" + +# ------------------------ +# 辅助函数 (No changes needed here, includes retry and uses visible GPU) +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。 + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + sample = common_voice_dataset.select([idx])[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU + sr = audio['sampling_rate'] + if sr != sample_rate: + if waveform.dim() > 1: + waveform = waveform.mean(dim=0) + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + if waveform.numel() == 0 or not sample['raw_text']: + logging.warning("Got an empty prompt, trying again...") + return get_random_prompt(common_voice_dataset, sample_rate) + # Return CPU tensor, CosyVoice inference should handle moving it + return waveform, sample['raw_text'] + +def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + Includes retry logic on failure. 
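+    Retries use exponential backoff with jitter before giving up and returning None.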
Assumes cosyvoice runs on the configured device. + """ + last_exception = None + for attempt in range(max_retries): + try: + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) + + all_speech = [] + inference_generator = cosyvoice.inference_zero_shot( + query_text, + prompt_text, + prompt_speech, # Pass CPU tensor + stream=stream, + text_frontend=False + ) + for i, chunk in enumerate(inference_generator): + if 'tts_speech' in chunk and chunk['tts_speech'] is not None: + all_speech.append(chunk['tts_speech']) + else: + logging.warning(f"TTS Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'") + + if not all_speech: + if torch.cuda.is_available(): torch.cuda.empty_cache() + raise ValueError("TTS inference finished but produced no audio chunks.") + + combined_speech = torch.cat(all_speech, dim=-1) # On GPU + sample_rate_val = cosyvoice.sample_rate + + return { + 'audio_tensor': combined_speech, # Return GPU tensor + 'sample_rate': sample_rate_val + } + except Exception as e: + last_exception = e + logging.error(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}", exc_info=True) + logging.error(f"Failed Text: '{query_text[:100]}...'") + logging.error(f"Prompt Text Used: '{prompt_text[:100]}...'") + if torch.cuda.is_available(): torch.cuda.empty_cache() + if attempt < max_retries - 1: + wait_time = RETRY_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(0.5, 1.5) + logging.warning(f"Retrying TTS with a different prompt in {wait_time:.2f}s...") + time.sleep(wait_time) + else: + logging.error(f"All {max_retries} TTS attempts failed.") + + logging.error(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'") + logging.error(f"Last TTS error: {last_exception}") + return None + +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对从磁盘加载的 LLM rephrased 数据集中的单个样本进行 TTS 处理。 + Processes example['query_rephrased']. <--- Target the rephrased query + """ + # --- Target the 'query_rephrased' field from the LLM output dataset --- + text_to_convert = example.get('query_rephrased') # <--- Use 'query_rephrased' field + if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "": + original_query = example.get('query', [{}])[0].get('content', 'Original Query Missing')[:50] + logging.warning(f"Skipping TTS for example due to missing or empty 'query_rephrased' field. Original query started with: '{original_query}...'. Status was: {example.get('query_rephrased_status','N/A')}") + return None + + # --- Use the text_to_audio function with retry logic --- + audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False) + + if audio_result is not None: + audio_tensor = audio_result['audio_tensor'] # Still on GPU here + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + elif audio_tensor.dim() > 2: + logging.warning(f"Generated audio tensor has unexpected shape {audio_tensor.shape}. 
Attempting to flatten.") + audio_tensor = audio_tensor.view(1, -1) + + if audio_tensor.numel() == 0: + logging.warning(f"Generated audio tensor is empty for rephrased query: '{text_to_convert[:60]}...'") + return None + + return { + 'audio_tensor': audio_tensor, # Return GPU tensor + 'sample_rate': audio_result['sample_rate'] + } + else: + # text_to_audio already logged the failure + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +logging.info("Loading VoxPopuli (as Common Voice) dataset for prompts...") +try: + common_voice = load_dataset("facebook/voxpopuli", "en", split='train') + logging.info(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + if len(common_voice) == 0: + raise ValueError("VoxPopuli dataset loaded but contains no samples.") +except Exception as e: + logging.error(f"Error loading VoxPopuli dataset: {e}", exc_info=True) + sys.exit(1) + + +logging.info("Initializing CosyVoice2 model...") +try: + cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', + load_jit=True, + load_trt=False, + fp16=False # Consider setting to True if VRAM is an issue and you have FP16 support + ) + logging.info(f"CosyVoice model initialized on effective device: {effective_device}") +except Exception as e: + logging.error(f"Error initializing CosyVoice2 model: {e}", exc_info=True) + sys.exit(1) + +logging.info(f"Loading rephrased dataset from disk: {REPHRASED_DATASET_PATH}") +try: + # --- Load the single dataset saved by the LLM script --- + rephrased_dataset = load_from_disk(REPHRASED_DATASET_PATH) + if not rephrased_dataset: + raise ValueError(f"Dataset loaded from '{REPHRASED_DATASET_PATH}' is empty or invalid.") + # --- Wrap it in a dict to match the loop structure expecting splits --- + # Use the assumed split name as the key + dataset_dict = {ASSUMED_INPUT_SPLIT: rephrased_dataset} + logging.info(f"Successfully loaded dataset with {len(rephrased_dataset)} examples.") +except FileNotFoundError: + logging.error(f"Error: Rephrased dataset not found at '{REPHRASED_DATASET_PATH}'.") + logging.error("Please ensure the LLM rephrasing script ran successfully and saved data to the correct location.") + sys.exit(1) +except Exception as e: + logging.error(f"Error loading rephrased dataset from '{REPHRASED_DATASET_PATH}': {e}", exc_info=True) + sys.exit(1) + +# 创建输出目录 +os.makedirs(TTS_OUTPUT_PATH, exist_ok=True) + +# ------------------------ +# 主处理循环 +# ------------------------ +final_dataset_dict_for_tracking = {} # To track which final datasets were saved + +# Iterate through the dictionary (will contain only one split, e.g., 'train') +for split_name, split_dataset in dataset_dict.items(): + logging.info(f"\nProcessing split: {split_name} with {len(split_dataset)} examples for TTS") + # Output directory for *this* script's results (audio + final dataset) + split_output_dir = os.path.join(TTS_OUTPUT_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + logging.info(f"Audio files and final data for this split will be saved in: {split_output_dir}") + + # 用于断点续跑的进度记录 (specific to this TTS process) + progress_file = os.path.join(split_output_dir, "tts_progress.txt") # Use specific name + start_index = 0 + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + content = f.read().strip() + if content: + start_index = int(content) + logging.info(f"Resuming split '{split_name}' TTS from sample index {start_index}") + else: + logging.info(f"Progress file '{progress_file}' is 
empty, starting TTS from index 0.") + start_index = 0 + except ValueError: + logging.warning(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.") + start_index = 0 + except Exception as e: + logging.error(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.") + start_index = 0 + else: + logging.info(f"No progress file found at '{progress_file}'. Starting TTS from index 0.") + + + # --- [NEW] Section: Check and Save Already Completed Samples --- + already_processed_samples = [] + if start_index > 0: + logging.info(f"Checking for already processed samples (audio files) up to index {start_index - 1}...") + # Use tqdm here for visibility if start_index is large + for j in tqdm(range(start_index), desc="Checking existing audio"): + potential_output_wav_path = os.path.join(split_output_dir, f"{j}.wav") + if os.path.exists(potential_output_wav_path): + try: + # Ensure index is valid before accessing + if j < len(split_dataset): + original_sample = split_dataset[j] + # Create a dict with all original keys + the existing audio path + completed_sample_dict = {k: original_sample[k] for k in original_sample.keys()} + completed_sample_dict["audio_filepath"] = potential_output_wav_path # Point to the existing audio + already_processed_samples.append(completed_sample_dict) + else: + logging.warning(f"Index {j} is out of bounds for the loaded dataset (size {len(split_dataset)}) while checking existing files. Skipping this index.") + except Exception as e: + logging.error(f"Error processing data for existing sample index {j}: {e}") + + if already_processed_samples: + logging.info(f"Found {len(already_processed_samples)} samples with existing audio files before the resume point.") + # Define path for the dataset of already processed samples + already_processed_dataset_path = os.path.join(split_output_dir, "already_processed_dataset") + try: + logging.info(f"Saving these {len(already_processed_samples)} already processed samples to: {already_processed_dataset_path}") + + # Define features based on the original dataset + the new audio_filepath column + original_features = split_dataset.features + new_features_dict = original_features.copy() + if "audio_filepath" not in new_features_dict: + new_features_dict["audio_filepath"] = Value('string') + new_features = Features(new_features_dict) + + already_processed_dataset = Dataset.from_list(already_processed_samples, features=new_features) + already_processed_dataset.save_to_disk(already_processed_dataset_path) + logging.info(f"Successfully saved dataset of {len(already_processed_samples)} already processed samples.") + # Clear the list to free memory + del already_processed_samples + if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache after potential large list processing + except Exception as e: + logging.error(f"Failed to create or save the dataset of already processed samples: {e}", exc_info=True) + # Keep already_processed_samples list in memory in case of save failure? Maybe not needed. 
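Editor's note: the schema handling above is the pattern worth keeping in mind — Dataset.from_list can mis-infer column types when some values are None, so the script copies the source Features and pins the new string column explicitly. A condensed sketch of the same pattern (the helper name with_audio_column is hypothetical):

from datasets import Dataset, Features, Value

def with_audio_column(samples, source_features):
    feats = dict(source_features)  # copy the original schema
    if "audio_filepath" not in feats:
        feats["audio_filepath"] = Value("string")
    return Dataset.from_list(samples, features=Features(feats))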
+ else: + logging.info("No existing audio files found for samples before the resume point (index 0 to {}).".format(start_index - 1)) + # --- [END NEW] Section --- + + + # --- Main processing loop --- + final_samples = [] # List to hold ALL processed sample dictionaries for the FINAL dataset of this run + logging.info(f"Starting/Resuming TTS processing from index {start_index}...") + pbar = tqdm(range(start_index, len(split_dataset)), desc=f"TTS on '{split_name}' query_rephrased", initial=start_index, total=len(split_dataset)) + for i in pbar: + try: + sample = split_dataset[i] # Sample data is on CPU + except IndexError: + logging.error(f"Index {i} is out of bounds for split_dataset (size {len(split_dataset)}). Stopping processing.") + break # Stop if we somehow go out of bounds + + # Define path for the *new* audio file (or potentially existing one) + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + + # Check if this specific TTS output file ALREADY exists + # This handles cases where the script stopped AFTER saving audio but BEFORE updating progress + # OR if files were somehow generated but progress file was lost/reset. + if os.path.exists(output_wav_path): + logging.debug(f"Audio file already exists for index {i} at {output_wav_path}. Skipping TTS, adding to final list.") + # Create the dict with all original keys + the existing audio path + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path # Point to the existing audio + final_samples.append(sample_dict) # Add to the list for the *final* dataset + # Update progress even if skipped due to existing file (important!) + with open(progress_file, "w") as f: + f.write(str(i + 1)) + continue # Move to the next sample + + # --- Perform TTS on the 'query_rephrased' field (UNCHANGED CORE LOGIC) --- + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + audio_tensor = result['audio_tensor'] # Received tensor is on GPU + sample_rate_val = result['sample_rate'] + + try: + # --- Move tensor to CPU before saving --- + audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32) + if audio_tensor_save.dim() == 1: + audio_tensor_save = audio_tensor_save.unsqueeze(0) + elif audio_tensor_save.dim() > 2: + audio_tensor_save = audio_tensor_save.view(1, -1) + + torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val) + + # --- Preserve all original fields + add the NEW audio path --- + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path # Add the path to the new audio + final_samples.append(sample_dict) # Add to the list for the *final* dataset + + # Explicitly delete GPU tensor and clear cache + del audio_tensor + del audio_tensor_save + if torch.cuda.is_available(): torch.cuda.empty_cache() + + + except Exception as e: + logging.error(f"Failed to save wav for sample {i} (TTS of query_rephrased) at {output_wav_path}: {e}", exc_info=True) + if torch.cuda.is_available(): torch.cuda.empty_cache() # Clear cache on save error too + else: + # process_example already logged the failure + logging.warning(f"Sample {i} TTS failed after retries (Rephrased Query: '{str(sample.get('query_rephrased', 'N/A'))[:60]}...'), no audio generated. This sample will NOT be included in the final dataset.") + # Decide whether to add sample without audio path or skip it + # Skipping for now, as audio is the goal. 
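+            # Note: the progress file below advances even on failure, so a failed
+            # sample is not retried on a later resume; it is simply absent from the
+            # final dataset.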
If you wanted to include failed ones: + # sample_dict = {k: sample[k] for k in sample.keys()} + # sample_dict["audio_filepath"] = None # Indicate missing audio + # final_samples.append(sample_dict) + + + # Update progress file after processing each sample (success or failure to ensure resume point advances) + # Make sure this is outside the 'if result is not None' block + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # Optional: More frequent cache clearing (Uncomment if needed) + # if i % 50 == 0 and torch.cuda.is_available(): + # torch.cuda.empty_cache() + # logging.debug(f"Cleared CUDA cache at index {i}") + + # --- Final cache clear after finishing the split --- + if torch.cuda.is_available(): + logging.info("Clearing final CUDA cache for the split.") + torch.cuda.empty_cache() + + # --- Save the final dataset object for this split (contains items found + newly generated) --- + if final_samples: + # Define features based on the original dataset + the new audio_filepath column for the final dataset + final_features_dict = split_dataset.features.copy() + if "audio_filepath" not in final_features_dict: + final_features_dict["audio_filepath"] = Value('string') + final_features = Features(final_features_dict) + + try: + logging.info(f"Attempting to create final dataset object from {len(final_samples)} collected samples...") + final_dataset_obj = Dataset.from_list(final_samples, features=final_features) + + # Save the final dataset object inside the split's output directory + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") # Name for the complete dataset + + logging.info(f"Saving final dataset object for split '{split_name}' (with audio paths) to {final_dataset_save_path}...") + os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True) # Should exist, but safety check + final_dataset_obj.save_to_disk(final_dataset_save_path) + logging.info(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples in the final dataset at '{final_dataset_save_path}'.") + final_dataset_dict_for_tracking[split_name] = final_dataset_obj # Keep track if needed + + except Exception as e: + logging.error(f"Error creating or saving final dataset object for split '{split_name}': {e}", exc_info=True) + logging.error("Attempting to save final_samples list as JSON Lines as a fallback...") + fallback_path = os.path.join(split_output_dir, "final_samples_fallback.jsonl") # Use distinct fallback name + try: + with open(fallback_path, 'w', encoding='utf-8') as f: + for item in final_samples: + # Basic serialization attempt + serializable_item = {} + for k, v in item.items(): + if isinstance(v, torch.Tensor): + serializable_item[k] = f"Tensor data (shape: {v.shape})" # Placeholder + elif isinstance(v, (dict, list, str, int, float, bool, type(None))): + serializable_item[k] = v + else: + serializable_item[k] = str(v) # Attempt string conversion for others + f.write(json.dumps(serializable_item) + '\n') + logging.info(f"Fallback JSON Lines saved to {fallback_path}") + except Exception as json_e: + logging.error(f"Fallback JSON save failed: {json_e}", exc_info=True) + + else: + logging.warning(f"Finished processing split: {split_name}. No samples were successfully processed or found with existing audio during this run to add to the final dataset.") + + +print("="*30) +if final_dataset_dict_for_tracking: + logging.info(f"All specified splits processed for TTS. 
Final datasets saved in respective 'final_dataset' subdirectories within '{TTS_OUTPUT_PATH}'.")
+    logging.info(f"Processed splits where final datasets were generated: {list(final_dataset_dict_for_tracking.keys())}")
+    logging.info("Additionally, if resuming, datasets containing only the samples processed *before* this run may have been saved in 'already_processed_dataset' subdirectories.")
+else:
+    logging.warning(f"TTS processing finished, but no final datasets were generated or saved in the 'final_dataset' subdirectories within '{TTS_OUTPUT_PATH}'. Check logs for errors. Pre-existing data might be in 'already_processed_dataset' if resuming.")
+print("="*30)
\ No newline at end of file
diff --git a/r1-a/dataset/examqa_rewrite.py b/r1-a/dataset/examqa_rewrite.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c964fac0cce28a5e63af2d97b516094be0db31e
--- /dev/null
+++ b/r1-a/dataset/examqa_rewrite.py
@@ -0,0 +1,487 @@
+import os
+import http.client
+import json
+import time
+import random
+from datasets import load_dataset, Dataset, DatasetDict, Features, Value
+from tqdm.auto import tqdm
+import sys
+import logging
+import getpass
+import signal
+import socket
+import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
+import argparse # For command-line arguments
+
+# --- Configuration ---
+DATASET_NAME = "virtuoussy/Multi-subject-RLVR"
+DATASET_SPLIT = "train"
+API_HOST = "api2.aigcbest.top"
+API_PATH = "/v1/chat/completions"
+LLM_MODEL = "gpt-4.1-mini"
+# Read the key from the environment only; never commit a real key as the fallback.
+API_KEY = os.environ.get('AIGCBEST_API_KEY', "YOUR_API_KEY_HERE")
+if not API_KEY or API_KEY == "YOUR_API_KEY_HERE":
+    print("API Key is not set correctly. Please set the AIGCBEST_API_KEY environment variable or replace the placeholder.")
+    sys.exit(1)
+
+OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
+# Define the path where the *potentially incomplete* processed dataset exists
+PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
+# Define where the *final, retried* dataset will be saved
+FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final") # Save to new location initially
+
+BATCH_SAVE_SIZE = 500 # How often to save intermediate progress *during retry*
+MAX_WORKERS = 20
+REQUEST_DELAY_SECONDS = 0.15
+MAX_RETRIES = 3
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.getLogger("datasets").setLevel(logging.WARNING)
+logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
+
+# --- LLM API Function (call_llm_api) ---
+# Retries transient failures with backoff; returns None on permanent failure.
+def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
+    system_prompt = (
+        "You are an expert linguist specializing in converting structured prompts or "
+        "fill-in-the-blank problems into natural, spoken-language questions suitable for "
+        "text-to-speech (TTS). Your goal is to make the question sound like how a person "
+        "would naturally ask it. "
+        "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
+        "rephrase it as a direct question asking for the missing information. "
+        "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
+        "Focus only on rephrasing the *user's question* part provided. "
+        "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
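+        # Illustrative of the intended transformation (editor's example with
+        # hypothetical data, not drawn from the dataset):
+        #   in : "The capital of France is -----."
+        #   out: "What is the capital of France?"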
+ ) + payload = json.dumps({ + "model": model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": original_question} + ], + }) + headers = { + 'Accept': 'application/json', + 'Authorization': f'Bearer {api_key}', + 'User-Agent': 'HuggingFace Dataset Processing Script (Retry Mode)', + 'Content-Type': 'application/json' + } + time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2)) + + for attempt in range(retries): + logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...") + try: + conn = http.client.HTTPSConnection(host, timeout=60) + conn.request("POST", path, payload, headers) + res = conn.getresponse() + status = res.status + data = res.read() + conn.close() + + if status == 200: + response_json = json.loads(data.decode("utf-8")) + if response_json.get("choices") and len(response_json["choices"]) > 0: + message = response_json["choices"][0].get("message") + if message and message.get("content"): + rephrased = message["content"].strip() + if len(rephrased) > 1 and ((rephrased.startswith('"') and rephrased.endswith('"')) or \ + (rephrased.startswith("'") and rephrased.endswith("'"))): + rephrased = rephrased[1:-1] + if rephrased and rephrased.strip().lower() != original_question.strip().lower(): + logging.debug(f"Successfully rephrased: {rephrased[:80]}...") + return rephrased + elif not rephrased: + logging.warning(f"LLM returned empty response for: {original_question[:50]}...") + return None + else: + logging.warning(f"LLM returned identical response for: {original_question[:50]}...") + return None # Treat identical as failure for rephrasing + logging.error(f"Unexpected API response structure: {data.decode('utf-8')}") + return None + elif status == 429: + retry_after_header = res.getheader('Retry-After', '5') + try: wait_time = int(retry_after_header) + except ValueError: wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5) + logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...") + time.sleep(wait_time) + elif status >= 500: + wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5) + logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...") + time.sleep(wait_time) + else: + logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}") + return None + except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e: + logging.error(f"Network/HTTP error during API call: {e}. Attempt {attempt + 1}/{retries}") + if attempt + 1 == retries: return None + wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3) + logging.warning(f"Waiting {wait_time:.2f} seconds before retry...") + time.sleep(wait_time) + except json.JSONDecodeError as e: + logging.error(f"Failed to decode API response: {e}. 
Response snippet: {data[:200]}") + if attempt + 1 == retries: return None + wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5) + time.sleep(wait_time) # Wait before next attempt + except Exception as e: + logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True) + if attempt + 1 == retries: return None + wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3) + logging.warning(f"Waiting {wait_time:.2f} seconds before retry...") + time.sleep(wait_time) + + logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...") + return None + + +# --- Dataset Processing Function (rephrase_query_entry) --- +# Same as before, returns the full dictionary with status +def rephrase_query_entry(example): + processed_example = example.copy() + # Ensure status field exists, default to unprocessed if missing + if 'query_rephrased_status' not in processed_example: + processed_example['query_rephrased_status'] = 'unprocessed' + + original_query_list = example.get("query") + + # --- Input Validation --- + if original_query_list is None: + processed_example['query_rephrased_status'] = 'skipped_missing_query_column' + processed_example['query_rephrased'] = None + return processed_example + if not isinstance(original_query_list, list): + processed_example['query_rephrased_status'] = 'skipped_query_not_list' + processed_example['query_rephrased'] = None + return processed_example + if not original_query_list: + processed_example['query_rephrased_status'] = 'skipped_query_list_empty' + processed_example['query_rephrased'] = None + return processed_example + + # --- Find User Question --- + user_question = None + for i, message in enumerate(original_query_list): + if isinstance(message, dict) and message.get("role") == "user": + content = message.get("content") + if isinstance(content, str) and content.strip(): + user_question = content + break + else: + processed_example['query_rephrased_status'] = 'skipped_invalid_user_content' + processed_example['query_rephrased'] = None + return processed_example + + if not user_question: + processed_example['query_rephrased_status'] = 'skipped_no_user_content_found' + processed_example['query_rephrased'] = None + return processed_example + + # --- Call LLM API --- + logging.info(f"Attempting to rephrase: {user_question[:60]}...") # Log retry attempt + rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL) + + # --- Update Example Based on API Result --- + if rephrased_query_content: + logging.debug(f"Rephrased '{user_question[:30]}...' 
to '{rephrased_query_content[:30]}...'") + processed_example["query_rephrased"] = rephrased_query_content + processed_example['query_rephrased_status'] = 'success_retried' # New status for successful retry + else: + logging.warning(f"Retry failed for user question: {user_question[:50]}...") + # Keep existing rephrased content (likely None) but update status + processed_example['query_rephrased_status'] = 'failed_llm_retry' # New status for failed retry + + return processed_example + + +# --- Function to Save Progress --- +# Saves the *entire list* of dictionaries +def save_final_dataset(data_list, output_path): + """Saves the final list of processed data dictionaries.""" + if not data_list: + logging.info("No data provided for saving.") + return False + logging.info(f"Attempting to save {len(data_list)} final examples to {output_path}...") + try: + # Define features explicitly to handle potential Nones and ensure consistency + # Adjust types based on your actual dataset structure + features = Features({ + 'query': [{'role': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None)}], + 'query_rephrased': Value(dtype='string', id=None), # Allow nulls + 'query_rephrased_status': Value(dtype='string', id=None), # Allow nulls + # Add other columns from your original dataset here... + # Example: 'answer': Value(dtype='string', id=None), + # Example: 'subject': Value(dtype='string', id=None), + # IMPORTANT: List all columns present in your loaded dataset + 'query_code': Value(dtype='string', id=None), + 'answer': Value(dtype='string', id=None), + 'answer_code': Value(dtype='string', id=None), + 'subject': Value(dtype='string', id=None), + 'grade': Value(dtype='string', id=None), + 'source': Value(dtype='string', id=None), + 'split': Value(dtype='string', id=None), + '__index_level_0__': Value(dtype='int64', id=None) # Check if this column exists + }) + + # Clean data slightly - replace python None with "" for string fields if needed by Arrow + # or ensure feature definition handles nulls correctly (Value(dtype='string', id=None) should) + # cleaned_data_list = [] + # for item in data_list: + # cleaned_item = item.copy() + # for key, feature_type in features.items(): + # if isinstance(feature_type, Value) and feature_type.dtype == 'string': + # if cleaned_item.get(key) is None: + # cleaned_item[key] = "" # Or keep None if schema allows + # cleaned_data_list.append(cleaned_item) + + + # Use the original list directly if schema handles None + processed_dataset = Dataset.from_list(list(data_list), features=features) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + processed_dataset.save_to_disk(output_path) + logging.info(f"Successfully saved final dataset ({len(data_list)} items) to {output_path}") + return True + except Exception as e: + logging.error(f"Failed to save final dataset to {output_path}: {e}", exc_info=True) + # Try saving as JSON as a fallback + fallback_json_path = output_path + ".jsonl" + logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}") + try: + with open(fallback_json_path, 'w', encoding='utf-8') as f: + for item in data_list: + f.write(json.dumps(item, ensure_ascii=False) + '\n') + logging.info(f"Successfully saved fallback JSON Lines file to {fallback_json_path}") + except Exception as json_e: + logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True) + return False + + +# --- Helper to get original user query --- +def get_user_query(example): + """Extracts the user query content from the 'query' 
list.""" + query_list = example.get("query") + if isinstance(query_list, list): + for message in query_list: + if isinstance(message, dict) and message.get("role") == "user": + content = message.get("content") + if isinstance(content, str) and content.strip(): + return content + return None + + +# --- Function to Check if Retry is Needed --- +def needs_retry(example): + """Determines if an example needs reprocessing based on its current state.""" + status = example.get('query_rephrased_status') + rephrased_text = example.get('query_rephrased') + + # Condition 1: Explicit failure status from previous (new script) run + if status in ['failed_llm_call', 'failed_llm_retry', 'failed_processing_exception']: + return True + + # Condition 2: Certain 'skipped' statuses might warrant a retry (optional, adjust as needed) + # For example, if the user content was invalid originally, retrying won't help. + # if status in ['skipped_no_user_content_found']: # Decide if these should be retried + # return True + + # Condition 3: Status indicates success OR status is missing/old, + # BUT the rephrased text is missing or empty. This catches failures + # from the *old* script or inconsistent states. + if rephrased_text is None or not str(rephrased_text).strip(): + # Don't retry if it was intentionally skipped due to bad input + if status not in ['skipped_missing_query_column', 'skipped_query_not_list', + 'skipped_query_list_empty', 'skipped_invalid_user_content', + 'skipped_no_user_content_found']: + return True + + # Condition 4 (Optional but recommended): Check if rephrased text is identical to original user query + # This requires extracting the original query here. + # original_user_query = get_user_query(example) + # if original_user_query and isinstance(rephrased_text, str) and \ + # rephrased_text.strip().lower() == original_user_query.strip().lower(): + # # Check status first - if it was intentionally skipped, don't retry + # if status not in ['skipped_missing_query_column', 'skipped_query_not_list', + # 'skipped_query_list_empty', 'skipped_invalid_user_content', + # 'skipped_no_user_content_found']: + # logging.debug(f"Identified identical query/rephrased text for retry: {original_user_query[:50]}...") + # return True + + + # Default: No retry needed + return False + + +# --- Main Execution --- +if __name__ == "__main__": + start_time = time.time() + logging.info("======================================================") + logging.info(f" Starting Dataset Processing Script in RETRY MODE") + logging.info("======================================================") + logging.info(f"Dataset: {DATASET_NAME}, Split: {DATASET_SPLIT}") + logging.info(f"Loading existing processed data from: {PROCESSED_DATA_PATH}") + logging.info(f"Final output will be saved to: {FINAL_OUTPUT_PATH}") + logging.info(f"Max concurrent workers: {MAX_WORKERS}") + + # --- Load Existing Processed Dataset --- + if not os.path.exists(PROCESSED_DATA_PATH): + logging.error(f"Existing processed data not found at '{PROCESSED_DATA_PATH}'. 
Cannot run in retry mode.") + sys.exit(1) + + logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...") + try: + # Load the dataset saved by the previous script run + existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH) + # Convert to list of dictionaries for easier modification access by index + # Be mindful of memory usage for very large datasets + results_list = existing_dataset.to_list() + total_examples = len(results_list) + logging.info(f"Loaded {total_examples} examples.") + # Ensure essential columns exist, add them if missing from old format + for i in range(total_examples): + if 'query_rephrased' not in results_list[i]: + results_list[i]['query_rephrased'] = None + if 'query_rephrased_status' not in results_list[i]: + results_list[i]['query_rephrased_status'] = 'unknown_original_status' + + except Exception as e: + logging.error(f"Failed to load existing dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True) + sys.exit(1) + + + # --- Identify Indices to Retry --- + indices_to_retry = [ + i for i, example in enumerate(results_list) if needs_retry(example) + ] + num_to_retry = len(indices_to_retry) + + if num_to_retry == 0: + logging.info("No examples found needing retry based on the criteria.") + logging.info(f"The dataset at {PROCESSED_DATA_PATH} is considered final.") + # Optional: You might still want to save it to FINAL_OUTPUT_PATH for consistency + # if not os.path.exists(FINAL_OUTPUT_PATH): + # save_final_dataset(results_list, FINAL_OUTPUT_PATH) + sys.exit(0) + + logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.") + + # --- Prepare for Concurrent Retries --- + processed_count_in_retry = 0 + # We don't need batch saving in the same way, but can update the list in memory + # A temporary dictionary to store results from futures before updating the main list + retry_results_dict = {} + + logging.info("Starting concurrent processing for examples needing retry...") + + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + # Submit jobs only for the indices needing retry + # Pass the *specific example dictionary* to the function + futures = { + executor.submit(rephrase_query_entry, results_list[i]): i + for i in indices_to_retry + } + + try: + pbar = tqdm(total=num_to_retry, desc="Retrying failed examples", unit="example") + for future in concurrent.futures.as_completed(futures): + original_index = futures[future] # Get the index in the full results_list + try: + # Get the updated dictionary result from the retry attempt + updated_example_dict = future.result() + # Store the result temporarily, keyed by original index + retry_results_dict[original_index] = updated_example_dict + pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True) + + except Exception as exc: + # Catch errors *during* the retry processing itself + logging.error(f'Retry for example index {original_index} generated an exception: {exc}', exc_info=True) + # Create a placeholder indicating the retry attempt failed due to an exception + error_placeholder = results_list[original_index].copy() # Get original data again + error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}' + # Store this error placeholder + retry_results_dict[original_index] = error_placeholder + pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True) + + finally: + processed_count_in_retry += 1 + pbar.update(1) + # Optional intermediate save logic (maybe save 
every N retries) + # Could save the *entire* potentially partially updated list, but might be slow. + # if processed_count_in_retry % BATCH_SAVE_SIZE == 0: + # logging.info(f"Processed {processed_count_in_retry} retries, updating intermediate state...") + # # Update the main list with results gathered so far + # for idx, updated_item in retry_results_dict.items(): + # results_list[idx] = updated_item + # # Clear the temporary dict after updating + # retry_results_dict.clear() + # # Save the whole list (potentially slow) + # save_final_dataset(results_list, FINAL_OUTPUT_PATH + "_interim") + + + except KeyboardInterrupt: + logging.warning("\nCtrl+C detected during retry! Attempting to save progress...") + + except Exception as e: + logging.error(f"An unexpected error occurred during the retry loop: {e}", exc_info=True) + + finally: + if 'pbar' in locals(): + pbar.close() + + # --- Update the main results list with all completed retries --- + logging.info("Updating main results list with completed retry attempts...") + update_count = 0 + for idx, updated_item in retry_results_dict.items(): + if idx < len(results_list): + results_list[idx] = updated_item + update_count += 1 + else: + logging.error(f"Index {idx} from retry results is out of bounds for results_list (size {len(results_list)}). Skipping update.") + + logging.info(f"Applied updates for {update_count} retried items.") + + # --- Final Save --- + logging.info(f"Attempting to save the final updated dataset to: {FINAL_OUTPUT_PATH}") + if save_final_dataset(results_list, FINAL_OUTPUT_PATH): + logging.info("Final dataset saved successfully.") + # Optional: Suggest deleting the old intermediate path if successful + # logging.info(f"You may now safely remove the intermediate directory: {PROCESSED_DATA_PATH}") + else: + logging.error(">>> FINAL SAVE FAILED! <<<") + logging.error(f"Check the logs. 
The latest state might be in memory or a fallback JSON file if created.")
+
+
+    # --- Final Verification (Optional) ---
+    logging.info("------------------------------------------------------")
+    logging.info("Verification: Loading final saved dataset for status check...")
+    try:
+        final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH)
+        logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.")
+        status_counts = {}
+        for ex in final_reloaded_dataset:
+            status = ex.get('query_rephrased_status', 'unknown_status_field')
+            status_counts[status] = status_counts.get(status, 0) + 1
+
+        logging.info("Status counts in the final saved file:")
+        for status, count in sorted(status_counts.items()):
+            logging.info(f"  - {status}: {count}")
+
+        # Highlight remaining failures. Exception retries are recorded with a
+        # dynamic suffix ('failed_retry_exception_<ExcType>'), so count every
+        # status that starts with 'failed' instead of matching exact names.
+        remaining_failures = sum(
+            count for status, count in status_counts.items()
+            if status.startswith('failed')
+        )
+
+        if remaining_failures > 0:
+            logging.warning(f"Found {remaining_failures} examples still marked as failed after retry attempts.")
+        else:
+            logging.info("No examples remain marked as failed after the retry pass.")
+
+    except FileNotFoundError:
+        logging.error(f"Verification failed: Final saved dataset not found at {FINAL_OUTPUT_PATH}.")
+    except Exception as e:
+        logging.error(f"Failed to reload or verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True)
+
+
+    end_time = time.time()
+    logging.info("------------------------------------------------------")
+    logging.info(f"Retry script finished in {end_time - start_time:.2f} seconds.")
+    logging.info("======================================================")
\ No newline at end of file
diff --git a/r1-a/dataset/final_tts.py b/r1-a/dataset/final_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..835943063c0a4698f09f11a39ab16c929cc0799a
--- /dev/null
+++ b/r1-a/dataset/final_tts.py
@@ -0,0 +1,316 @@
+# --- ENVIRONMENT VARIABLE CONTROL ---
+import os
+import sys
+import random
+import torch
+import torchaudio
+from datasets import load_dataset, Dataset, load_from_disk, Features, Value, Audio
+from tqdm import tqdm
+import time
+import logging
+import json
+import math
+import pathlib
+import re
+import unicodedata
+
+# --- Read Environment Variables ---
+try:
+    # GPU ID for this specific run (0, 1, 2, or 3); the defaults preserve the
+    # previously hardcoded values when the variables are unset.
+    PROCESSING_GPU_ID = int(os.environ.get("PROCESSING_GPU_ID", "3"))
+    # Total number of parallel runs (should be 4)
+    TOTAL_PROCESSING_NODES = int(os.environ.get("TOTAL_PROCESSING_NODES", "4"))
+except ValueError:
+    print("Error: PROCESSING_GPU_ID and TOTAL_PROCESSING_NODES env vars must be integers.")
+    sys.exit(1)
+
+if not 0 <= PROCESSING_GPU_ID < TOTAL_PROCESSING_NODES:
+    print(f"Error: PROCESSING_GPU_ID ({PROCESSING_GPU_ID}) must be between 0 and {TOTAL_PROCESSING_NODES - 1}.")
+    sys.exit(1)
+
+print(f"--- Starting Run for Shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} ---")
+print(f"--- Targeting GPU Index (physical): {PROCESSING_GPU_ID} ---")
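+
+# Illustrative launch pattern (an assumption about deployment, not enforced by
+# this script): start one copy of this file per physical GPU, e.g.
+#   PROCESSING_GPU_ID=0 TOTAL_PROCESSING_NODES=4 python final_tts.py
+#   PROCESSING_GPU_ID=3 TOTAL_PROCESSING_NODES=4 python final_tts.py
+# and similarly for shards 1 and 2.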
+
+# --- SET VISIBLE CUDA DEVICE BEFORE ANY CUDA CONTEXT IS CREATED ---
+# torch is already imported above, but CUDA is not initialized until first use,
+# so restricting visibility here still makes the chosen GPU appear as 'cuda:0'
+# to this script instance.
+os.environ["CUDA_VISIBLE_DEVICES"] = str(PROCESSING_GPU_ID)
+
+# Check CUDA availability *after* setting visibility
+if not torch.cuda.is_available():
+    print(f"ERROR: CUDA device {PROCESSING_GPU_ID} is not available after setting CUDA_VISIBLE_DEVICES.")
+    sys.exit(1)
+else:
+    # PyTorch now sees the selected GPU as cuda:0
+    effective_device = torch.device("cuda:0")
+    try:
+        print(f"Script process {os.getpid()} successfully assigned to specific GPU: {torch.cuda.get_device_name(0)} (Original Index {PROCESSING_GPU_ID})")
+    except Exception as e:
+        print(f"Warning: Could not get device name, but CUDA is available. Error: {e}")
+        print(f"Script process {os.getpid()} assigned to specific GPU index {PROCESSING_GPU_ID}")
+
+
+# --- Add CosyVoice Path ---
+COSYVOICE_PATH = '/home/chenyifu/CosyVoice' # <-- Your path
+if COSYVOICE_PATH not in sys.path:
+    sys.path.append(COSYVOICE_PATH)
+
+# Import CosyVoice
+try:
+    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+    from cosyvoice.utils.file_utils import load_wav
+except ImportError as e:
+    print(f"Error importing CosyVoice: {e}")
+    sys.exit(1)
+
+# Setup basic logging for this instance
+logging.basicConfig(level=logging.INFO, format=f'%(asctime)s - %(levelname)s - [GPU-{PROCESSING_GPU_ID}] %(message)s')
+
+# ------------------------
+# Configuration Parameters - Mostly unchanged
+# ------------------------
+# --- Input Dataset ---
+INPUT_DATASET_PATH = '/home/chenyifu/audio-r1/r1-a/dataset/prompt_only'
+TEXT_FIELD_FOR_TTS = "question_text"
+AUDIO_PATH_FIELD = "question_audio" # Field name for the *final* aggregated dataset
+ASSUMED_INPUT_SPLIT_NAME = "train"
+
+# --- Output ---
+TTS_OUTPUT_BASE_PATH = '/home/chenyifu/audio-r1/r1-a/dataset/prompt_only_fully_merged_with_audio' # <<-- SHARED output path for all runs
+AUDIO_SUBFOLDER_NAME = 'audio_files' # <<-- SHARED audio subfolder
+
+# --- Prompt Audio Settings ---
+RAW_PROMPT_DATASET_PATH = "/home/chenyifu/audio-r1/r1-a/dataset/mls_eng10k"
+PROMPT_DATASET_SPLIT = "train"
+PROMPT_MIN_DURATION_S = 10
+PROMPT_MAX_DURATION_S = 13
+PROMPT_TEXT_FIELD = "transcript"
+PROMPT_AUDIO_DURATION_FIELD = "audio_duration"
+FILTERED_PROMPT_DATASET_PATH = f"{RAW_PROMPT_DATASET_PATH}_filtered_{PROMPT_MIN_DURATION_S}_{PROMPT_MAX_DURATION_S}s" # SHARED path
+
+# --- TTS Settings ---
+TARGET_SAMPLE_RATE = 16000 # Desired *prompt* sample rate
+MAX_TTS_RETRIES = 3
+RETRY_DELAY_SECONDS = 2
+
+# --- Processing Settings ---
+# TEST_SINGLE_SAMPLE = False # Not needed, shard logic handles subset
+# MULTI_GPU_PROCESSING = False # Not using mp.spawn
+
+# ------------------------
+# Helper Functions - Unchanged from previous multi-GPU capable script
+# ------------------------
+def preprocess_text(text):
+    """Normalize text before TTS: NFKC normalization, dashes to spaces,
+    invisible characters stripped, whitespace collapsed, terminal punctuation
+    ensured."""
+    if not isinstance(text, str):
+        return ""
+    text = unicodedata.normalize('NFKC', text)
+    text = re.sub(r'[—–―‐‑⁃﹣-]', ' ', text)
+    text = text.replace('\u00AD', '').replace('\u200B', '')
+    text = re.sub(r'\s+', ' ', text).strip()
+    if text and text[-1] not in ['.', '?', '!']:
+        text += '.'
+    return text
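+
+# Illustrative behaviour of preprocess_text (hypothetical inputs, not executed):
+#   preprocess_text("well-known   fact")   -> "well known fact."
+#   preprocess_text("Already terminated?") -> "Already terminated?"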
+
+def filter_prompt_logic(example):
+    # Prefer the precomputed duration column when it is present and numeric.
+    if PROMPT_AUDIO_DURATION_FIELD in example and isinstance(example[PROMPT_AUDIO_DURATION_FIELD], (int, float)):
+        duration = example[PROMPT_AUDIO_DURATION_FIELD]
+        return PROMPT_MIN_DURATION_S <= duration <= PROMPT_MAX_DURATION_S
+    # Otherwise fall back to computing the duration from the decoded audio.
+    try:
+        audio_info = example['audio']
+        samplerate = audio_info['sampling_rate']
+        duration = len(audio_info['array']) / samplerate
+        return PROMPT_MIN_DURATION_S <= duration <= PROMPT_MAX_DURATION_S
+    except Exception:
+        return False
+
+def get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate=TARGET_SAMPLE_RATE):
+    if not filtered_prompt_dataset or len(filtered_prompt_dataset) == 0:
+        raise ValueError("Filtered prompt dataset empty!")
+    idx = random.randint(0, len(filtered_prompt_dataset) - 1)
+    try:
+        sample = filtered_prompt_dataset[idx]
+    except IndexError:
+        sample = filtered_prompt_dataset[0]
+    audio_info = sample['audio']
+    prompt_text = sample[PROMPT_TEXT_FIELD]
+    if isinstance(audio_info, dict) and 'array' in audio_info:
+        waveform = torch.tensor(audio_info['array'], dtype=torch.float32)
+        sr = audio_info['sampling_rate']
+    elif isinstance(audio_info, str) or (isinstance(audio_info, dict) and 'path' in audio_info):
+        path = audio_info if isinstance(audio_info, str) else audio_info['path']
+        waveform, sr = torchaudio.load(path)
+    else:
+        raise TypeError("Unknown prompt audio format")
+    # Empty audio or missing transcript: draw another prompt.
+    if not prompt_text or waveform.numel() == 0:
+        return get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate)
+    if sr != target_sample_rate:
+        # Downmix to mono before resampling.
+        if waveform.dim() > 1 and waveform.shape[0] > 1:
+            waveform = waveform.mean(dim=0, keepdim=True)
+        elif waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+        waveform = resampler(waveform)
+    if waveform.dim() == 1:
+        waveform = waveform.unsqueeze(0)
+    elif waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    return waveform.cpu(), prompt_text
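+
+# get_random_valid_prompt returns a (1, num_samples) float32 CPU tensor at
+# target_sample_rate together with its transcript. Minimal illustrative use:
+#   prompt_speech, prompt_text = get_random_valid_prompt(filtered_prompt_dataset)
+#   assert prompt_speech.dim() == 2 and prompt_speech.shape[0] == 1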
+
+def text_to_audio(text_to_convert, cosyvoice, filtered_prompt_dataset, target_sample_rate, stream=False, max_retries=MAX_TTS_RETRIES):
+    cleaned_text = preprocess_text(text_to_convert)
+    if not cleaned_text:
+        logging.warning(f"Empty text after cleaning: '{text_to_convert[:60]}...'")
+        return None
+    last_exception = None
+    for attempt in range(max_retries):
+        try:
+            prompt_speech, prompt_text = get_random_valid_prompt(filtered_prompt_dataset, target_sample_rate)
+            all_speech = []
+            inference_generator = cosyvoice.inference_zero_shot(cleaned_text, prompt_text, prompt_speech, stream=stream)
+            for i, chunk in enumerate(inference_generator):
+                if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
+                    all_speech.append(chunk['tts_speech'])
+            if not all_speech:
+                raise ValueError(f"TTS produced no audio chunks. Cleaned: '{cleaned_text[:60]}...'")
+            combined_speech = torch.cat(all_speech, dim=-1)
+            actual_sample_rate = cosyvoice.sample_rate
+            return {'audio_tensor': combined_speech, 'sample_rate': actual_sample_rate}
+        except Exception as e:
+            last_exception = e
+            logging.error(f"TTS Error Attempt {attempt + 1}: {e}", exc_info=False)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            if attempt < max_retries - 1:
+                # Exponential backoff with jitter between attempts.
+                time.sleep(RETRY_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(0.5, 1.5))
+            else:
+                logging.error(f"All TTS retries failed for: '{cleaned_text[:60]}...'")
+    return None
+
+# -----------------------------
+# --- Main Execution Logic ----
+# -----------------------------
+if __name__ == "__main__":
+    # --- Load or Create Filtered Prompt Dataset (Safe for concurrent runs, only first one creates) ---
+    # Add a small delay + check to mitigate potential race condition on creation
+    if not os.path.exists(FILTERED_PROMPT_DATASET_PATH):
+        time.sleep(random.uniform(0, 2)) # Small random delay
+        if not os.path.exists(FILTERED_PROMPT_DATASET_PATH): # Double check
+            logging.info(f"Filtered prompt dataset not found. Attempting creation...")
+            try:
+                prompt_dataset_raw = load_dataset(RAW_PROMPT_DATASET_PATH, split=PROMPT_DATASET_SPLIT)
+                if 'audio' in prompt_dataset_raw.features and not isinstance(prompt_dataset_raw.features['audio'], Audio):
+                    prompt_dataset_raw = prompt_dataset_raw.cast_column("audio", Audio(decode=True))
+                filtered_ds = prompt_dataset_raw.filter(filter_prompt_logic, num_proc=max(1, os.cpu_count() // 2))
+                if len(filtered_ds) == 0:
+                    raise ValueError("No prompts left after filtering.")
+                cols = ['audio', PROMPT_TEXT_FIELD]
+                if PROMPT_AUDIO_DURATION_FIELD in prompt_dataset_raw.column_names:
+                    cols.append(PROMPT_AUDIO_DURATION_FIELD)
+                filtered_ds = filtered_ds.select_columns(cols)
+                filtered_ds.save_to_disk(FILTERED_PROMPT_DATASET_PATH)
+                logging.info(f"Filtered prompt dataset CREATED and saved to: {FILTERED_PROMPT_DATASET_PATH}")
+            except Exception as e:
+                logging.error(f"FATAL: Failed to create filtered prompt dataset: {e}", exc_info=True)
+                # If creation fails, other processes might also fail loading. Check logs.
+                sys.exit(1)
+        else:
+            logging.info(f"Filtered prompt dataset appeared while waiting. Proceeding.")
+
+    try:
+        logging.info(f"Loading filtered prompt dataset from: {FILTERED_PROMPT_DATASET_PATH}")
+        filtered_prompt_dataset = load_from_disk(FILTERED_PROMPT_DATASET_PATH)
+        logging.info(f"Loaded {len(filtered_prompt_dataset)} filtered prompts.")
+    except Exception as e:
+        logging.error(f"FATAL: Failed to load filtered prompt dataset: {e}", exc_info=True)
+        sys.exit(1)
+
+
+    # --- Initialize TTS Model (on the assigned GPU 'cuda:0') ---
+    logging.info("Initializing CosyVoice model for this process...")
+    try:
+        # Model will initialize on cuda:0, which corresponds to the selected physical GPU
+        cosyvoice = CosyVoice2(
+            f'{COSYVOICE_PATH}/pretrained_models/CosyVoice2-0.5B',
+            load_jit=True, load_trt=False, fp16=False
+        )
+        model_output_sr = cosyvoice.sample_rate
+        logging.info(f"CosyVoice initialized. 
Model output SR: {model_output_sr}") + except Exception as e: + logging.error(f"Error initializing CosyVoice2 model: {e}", exc_info=True) + sys.exit(1) + + + # --- Load Main Input Dataset --- + logging.info(f"Loading main input dataset from: {INPUT_DATASET_PATH}") + try: + input_dataset_full = load_from_disk(INPUT_DATASET_PATH) + dataset_size = len(input_dataset_full) + logging.info(f"Loaded main dataset with {dataset_size} examples.") + + # --- Add original index column --- + logging.info("Adding original index...") + def add_index(example, idx): example['original_index'] = idx; return example + input_dataset_with_indices = input_dataset_full.map(add_index, with_indices=True, num_proc=max(1, os.cpu_count() // 2)) + logging.info("Original index added.") + + # --- Shard the dataset for this specific run --- + logging.info(f"Selecting shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} for processing...") + dataset_shard = input_dataset_with_indices.shard( + num_shards=TOTAL_PROCESSING_NODES, + index=PROCESSING_GPU_ID, + contiguous=True # Potentially faster access + ) + shard_size = len(dataset_shard) + logging.info(f"This instance will process {shard_size} samples.") + + except Exception as e: + logging.error(f"Error loading or sharding main input dataset: {e}", exc_info=True) + sys.exit(1) + + # --- Define Shared Output Audio Directory --- + # All instances write to the same place + split_output_dir = os.path.join(TTS_OUTPUT_BASE_PATH, ASSUMED_INPUT_SPLIT_NAME) + split_audio_dir = os.path.join(split_output_dir, AUDIO_SUBFOLDER_NAME) + os.makedirs(split_audio_dir, exist_ok=True) # Ensure it exists + + # --- Process the Assigned Shard --- + logging.info(f"Starting TTS processing for shard {PROCESSING_GPU_ID + 1}...") + pbar = tqdm(total=shard_size, desc=f"GPU-{PROCESSING_GPU_ID} TTS", ncols=100) + for sample in dataset_shard: + try: + original_idx = sample['original_index'] + text_to_convert = sample.get(TEXT_FIELD_FOR_TTS) + + if not text_to_convert or not isinstance(text_to_convert, str) or not text_to_convert.strip(): + logging.warning(f"Skipping original index {original_idx}: missing/invalid text.") + pbar.update(1) + continue + + # Define audio path using original index -> SHARED audio dir + audio_filename = f"query_{original_idx}.wav" + absolute_audio_path = os.path.join(split_audio_dir, audio_filename) + + # Check if audio already exists (supports resuming any run) + if os.path.exists(absolute_audio_path): + logging.debug(f"Audio exists for original index {original_idx}, skipping.") + pbar.update(1) + continue + + # Perform TTS + tts_result = text_to_audio( + text_to_convert, + cosyvoice, + filtered_prompt_dataset, + TARGET_SAMPLE_RATE, # Prompt SR + stream=False + ) + + if tts_result is not None: + audio_tensor = tts_result['audio_tensor'] # GPU tensor + output_sample_rate = tts_result['sample_rate'] # Use model's actual SR + + try: + audio_tensor_cpu = audio_tensor.detach().cpu().to(torch.float32) + if audio_tensor_cpu.dim() == 1: audio_tensor_cpu = audio_tensor_cpu.unsqueeze(0) + elif audio_tensor_cpu.dim() > 2: audio_tensor_cpu = audio_tensor_cpu.view(1, -1) + + # Save to the SHARED audio directory + torchaudio.save(absolute_audio_path, audio_tensor_cpu, output_sample_rate) + logging.debug(f"Saved audio for original index {original_idx}") + + del audio_tensor, audio_tensor_cpu + if torch.cuda.is_available(): torch.cuda.empty_cache() + + except Exception as e: + logging.error(f"Failed to save audio for original index {original_idx}: {e}") + if 'audio_tensor' in locals(): del 
audio_tensor
+                    if 'audio_tensor_cpu' in locals(): del audio_tensor_cpu
+                    if torch.cuda.is_available(): torch.cuda.empty_cache()
+            else:
+                logging.warning(f"TTS failed for original index {original_idx} after retries.")
+
+        except Exception as e:
+            logging.error(f"Unexpected error processing sample for original index {sample.get('original_index', 'UNKNOWN')}: {e}", exc_info=True)
+            if torch.cuda.is_available(): torch.cuda.empty_cache()
+
+        finally:
+            pbar.update(1)
+
+    pbar.close()
+    logging.info(f"--- Finished processing shard {PROCESSING_GPU_ID + 1} / {TOTAL_PROCESSING_NODES} ---")
+
+    # --- End of Script ---
+    print("="*30)
+    print(f"Run for GPU {PROCESSING_GPU_ID} completed.")
+    print(f"Audio files (if successful) saved in: {split_audio_dir}")
+    print("IMPORTANT: Run the aggregation script AFTER all 4 runs are finished to create the final dataset object.")
+    print("="*30)
\ No newline at end of file
diff --git a/r1-a/dataset/gsm8k.py b/r1-a/dataset/gsm8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..d93b28d6d11d018e8996a3096dc56fdca03a7a41
--- /dev/null
+++ b/r1-a/dataset/gsm8k.py
@@ -0,0 +1,169 @@
+import os
+import random
+import torch
+import torchaudio
+from datasets import load_dataset, Dataset
+import sys
+from tqdm import tqdm
+
+sys.path.append('/root/autodl-tmp/CosyVoice')
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+
+# Configuration
+COMMON_VOICE_LANGUAGE = "en"
+DATASET_NAME = "gsm8k" # use the gsm8k dataset
+OUTPUT_DATASET_PATH = './gsm8k_with_audio'
+SAMPLE_RATE = 16000
+
+# --- Helper functions ---
+
+def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
+    """
+    Randomly draw one utterance and its transcript from the VoxPopuli dataset
+    (substituting for Common Voice) to use as the prompt.
+    """
+    idx = random.randint(0, len(common_voice_dataset) - 1)
+    sample = common_voice_dataset.select([idx])[0]
+    audio = sample['audio']
+    waveform = torch.tensor(audio['array'], dtype=torch.float32)
+    sr = audio['sampling_rate']
+    if sr != sample_rate:
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
+        waveform = resampler(waveform)
+    return waveform.unsqueeze(0), sample['raw_text']
+
+def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False):
+    """
+    Convert the input text to speech with the CosyVoice2 model, using a random
+    prompt for zero-shot inference.
+    """
+    try:
+        prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
+        # Optional: save prompt.wav for debugging
+        # torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE)
+        all_speech = []
+        for i, j in enumerate(cosyvoice.inference_zero_shot(
+            query_text,
+            prompt_text,
+            prompt_speech,
+            stream=stream,
+            text_frontend=False
+        )):
+            all_speech.append(j['tts_speech'])
+        # Concatenate all generated speech chunks into one long tensor
+        combined_speech = torch.cat(all_speech, dim=-1)
+        sample_rate_val = cosyvoice.sample_rate
+        return {'audio_tensor': combined_speech, 'sample_rate': sample_rate_val}
+    except Exception as e:
+        print(f"Error converting text to audio: {e}")
+        return None
+
+def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
+    """
+    Run TTS for a single gsm8k example. The question text is assumed to live in
+    the 'question' field and the answer in the 'answer' field.
+    """
+    query = example['question']
+    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
+    if audio_result is not None:
+        # Return the generated audio tensor and its sample rate
+        return {
+            'audio_tensor': audio_result['audio_tensor'],
+            'sample_rate': audio_result['sample_rate']
+        }
+    else:
+        return None
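+
+# Note on the upstream format: in openai/gsm8k each example carries a 'question'
+# string and an 'answer' string whose worked solution ends with a final line of
+# the form '#### <number>'. Only 'question' is synthesized below; 'answer' is
+# carried through unchanged.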
load_dataset("facebook/voxpopuli", "en", split='train') +print(f"Total Common Voice {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + +print("Initializing CosyVoice2 model...") +cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # 替换为实际的模型路径 + load_jit=True, + load_trt=False, + fp16=False +) + +print("Loading GSM8K dataset...") +dataset = load_dataset("openai/gsm8k", 'main') + +# 确保输出总目录存在 +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) + +# --- 主处理循环 --- +# 对每个 split 分别处理,每个样本处理后保存 .wav 文件和记录最终数据集信息 +final_dataset_dict = {} # 用于保存最终数据集的每个 split + +for split_name, split_dataset in dataset.items(): + print(f"Processing split: {split_name} with {len(split_dataset)} examples") + split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + + # 用于断点续转的进度记录 + progress_file = os.path.join(split_output_dir, "progress.txt") + start_index = 0 + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + start_index = int(f.read().strip()) + print(f"Resuming split '{split_name}' from sample index {start_index}") + except Exception as e: + print(f"读取进度文件失败:{e}") + + final_samples = [] # 用于存储最终数据集样本信息 + for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"): + if i < start_index: + # 如果样本已处理,则加载对应的 wav 文件路径(假设之前已经生成)并加入最终数据集 + sample = split_dataset[i] + wav_path = os.path.join(split_output_dir, f"{i}.wav") + # 仅当文件存在时才加入最终数据集 + if os.path.exists(wav_path): + final_samples.append({ + "question_text": sample["question"], + "answer": sample["answer"], + "audio_filepath": wav_path + }) + continue + + sample = split_dataset[i] + # 处理 TTS 转换 + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + # 确保 audio tensor shape 为 (channels, samples) + audio_tensor = result['audio_tensor'] + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + sample_rate_val = result['sample_rate'] + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + try: + torchaudio.save(output_wav_path, audio_tensor, sample_rate_val) + except Exception as e: + print(f"Failed to save wav for sample {i}: {e}") + continue + + # 将转换后的样本信息保存到最终数据集中 + final_samples.append({ + "question_text": sample["question"], + "answer": sample["answer"], + "audio_filepath": output_wav_path + }) + else: + print(f"Sample {i} processing failed, no audio generated.") + + # 更新进度记录 + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # 将当前 split 的最终数据集保存为 Hugging Face Dataset,并存盘 + final_dataset_obj = Dataset.from_list(final_samples) + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") + final_dataset_obj.save_to_disk(final_dataset_save_path) + print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.") + final_dataset_dict[split_name] = final_dataset_obj + +print("所有分割处理完毕,最终数据集已保存。") diff --git a/r1-a/dataset/gsm8k_with_audio/test/299.wav b/r1-a/dataset/gsm8k_with_audio/test/299.wav new file mode 100644 index 0000000000000000000000000000000000000000..e1c0841c46973b0b1727507696466cc84efb59bd --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/299.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:faf7e5efc632e11631a878b477523f9a010ad5e3bba9db05f3bc20cafadab95e +size 2277200 diff --git a/r1-a/dataset/gsm8k_with_audio/test/301.wav b/r1-a/dataset/gsm8k_with_audio/test/301.wav new file mode 100644 index 
0000000000000000000000000000000000000000..81da4b53178761d39246e156c8a1d2fbfda775a9 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/301.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5010c552eb3c63b659ef8d71af813db7706b55c302f9c9e1c17a831c8bc896a4 +size 2346320 diff --git a/r1-a/dataset/gsm8k_with_audio/test/302.wav b/r1-a/dataset/gsm8k_with_audio/test/302.wav new file mode 100644 index 0000000000000000000000000000000000000000..619e82e0aabd8fcf705fdd5c3ef9389102ad4747 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/302.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06e41bddcccc46feb3fd116d32c245ef65cf0434acda1313d020d286a947ddcd +size 917840 diff --git a/r1-a/dataset/gsm8k_with_audio/test/314.wav b/r1-a/dataset/gsm8k_with_audio/test/314.wav new file mode 100644 index 0000000000000000000000000000000000000000..19feb33f02b32f04dcc716ec47255ad6f5f13386 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/314.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:026282fcebd0c6ba5bef04291128bbc5dd82eadeb2c020faa0a53135477a7e19 +size 1332560 diff --git a/r1-a/dataset/gsm8k_with_audio/test/316.wav b/r1-a/dataset/gsm8k_with_audio/test/316.wav new file mode 100644 index 0000000000000000000000000000000000000000..2dabe4ba50708a2ab9c286bcc73df423d42aa6cc --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/316.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83a29e59fc73ac1eb3b49033da596e47141d5911ceae61effc736b4b307a17de +size 1505360 diff --git a/r1-a/dataset/gsm8k_with_audio/test/350.wav b/r1-a/dataset/gsm8k_with_audio/test/350.wav new file mode 100644 index 0000000000000000000000000000000000000000..6311a23dca7d8f4508530915c47612c11f4c24ec --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/350.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5fd7f8dbaf220cd8fe66a151c056e384e07b13cf76778983bbc605043e43d4d +size 1946960 diff --git a/r1-a/dataset/gsm8k_with_audio/test/358.wav b/r1-a/dataset/gsm8k_with_audio/test/358.wav new file mode 100644 index 0000000000000000000000000000000000000000..1dbc21aa21ae36d5c8c3f942007d03d24cc054fa --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/358.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a9416095b31c0d238bb1321e6108733973fa651cfcee1fca8f898d18470a89b +size 1363280 diff --git a/r1-a/dataset/gsm8k_with_audio/test/359.wav b/r1-a/dataset/gsm8k_with_audio/test/359.wav new file mode 100644 index 0000000000000000000000000000000000000000..2c08b33c02109b9494c8e78e04326cdbb3608f01 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/359.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1ba8571451ad8a8bce5fdf235517e0aae650b7a42d9cc763e92adc7e9c949de +size 1532240 diff --git a/r1-a/dataset/gsm8k_with_audio/test/369.wav b/r1-a/dataset/gsm8k_with_audio/test/369.wav new file mode 100644 index 0000000000000000000000000000000000000000..5d609d6ab3f6287e1a3d033bbabb296600c8ecee --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/369.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4365c6f3f261264849987b3a13a1d552c1004977c0afea259b9e29dad8fee08c +size 1866320 diff --git a/r1-a/dataset/gsm8k_with_audio/test/372.wav b/r1-a/dataset/gsm8k_with_audio/test/372.wav new file mode 100644 index 0000000000000000000000000000000000000000..903a0bac536619234fb55424e3cc889143f27beb --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/372.wav @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a22d041239c1485f42ac79c20fd1ae2a9174fef598e36030c516c882663eae3c +size 1251920 diff --git a/r1-a/dataset/gsm8k_with_audio/test/376.wav b/r1-a/dataset/gsm8k_with_audio/test/376.wav new file mode 100644 index 0000000000000000000000000000000000000000..fdde417199d3063e5cf1a460a3d409016e0d472c --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/376.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65160b2855b4f25f29a7e0f4b5da970cba1e60ae1f1e637e5adfd7495c896f79 +size 1198160 diff --git a/r1-a/dataset/gsm8k_with_audio/test/385.wav b/r1-a/dataset/gsm8k_with_audio/test/385.wav new file mode 100644 index 0000000000000000000000000000000000000000..a841f7fd07b357d56af76b7755a50605b4e817e1 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/385.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bbcea803a885d3d5cc8b17f23837a1820f119f9564006023f8781b96a2114f4 +size 2165840 diff --git a/r1-a/dataset/gsm8k_with_audio/test/394.wav b/r1-a/dataset/gsm8k_with_audio/test/394.wav new file mode 100644 index 0000000000000000000000000000000000000000..c0e6654786560f66e5f5d9124090b1bbde619c5b --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/394.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec4937195a9d64a026cd352a09943a56ee64344fedda73cb792d4d12b466fe3f +size 1344080 diff --git a/r1-a/dataset/gsm8k_with_audio/test/395.wav b/r1-a/dataset/gsm8k_with_audio/test/395.wav new file mode 100644 index 0000000000000000000000000000000000000000..7f278c778ec57eb573ca94b4d8ac8da0f7024e27 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/395.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f84925ad72d43ff5b7bf24aa71472f08bc0ff5b43a09ed92929848451401ff0 +size 1804880 diff --git a/r1-a/dataset/gsm8k_with_audio/test/397.wav b/r1-a/dataset/gsm8k_with_audio/test/397.wav new file mode 100644 index 0000000000000000000000000000000000000000..720dc364d8665d3fa5ebf9cfa4211f38d62ff708 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/397.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc7c4027e38fd2090b026f8112151fd564d5a80df9ff230dbac69cb082f60c59 +size 1056080 diff --git a/r1-a/dataset/gsm8k_with_audio/test/400.wav b/r1-a/dataset/gsm8k_with_audio/test/400.wav new file mode 100644 index 0000000000000000000000000000000000000000..bb9bcb1145c96aa69aa2fe832d5460b0b900c127 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/400.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf0dace278147001a9b14f5701f5d206039b26438e50f19bcaf6015b46dab78e +size 867920 diff --git a/r1-a/dataset/gsm8k_with_audio/test/401.wav b/r1-a/dataset/gsm8k_with_audio/test/401.wav new file mode 100644 index 0000000000000000000000000000000000000000..07eea5a3efcdc6420d7da3d3306d71268cedbfd7 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/401.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a494472207f4c7f2d067dfa1ce3c19d18702b8fca5385aa2f23b7b7a9a869767 +size 940880 diff --git a/r1-a/dataset/gsm8k_with_audio/test/447.wav b/r1-a/dataset/gsm8k_with_audio/test/447.wav new file mode 100644 index 0000000000000000000000000000000000000000..31a1e2c8f30b74dd1e6f00a9d68c31da3fd772d8 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/447.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f08dd8d815e539e2d6b2ea2c06320784e3f6b3f862e3d50e20a2c11ac38785b +size 1586000 diff --git 
a/r1-a/dataset/gsm8k_with_audio/test/45.wav b/r1-a/dataset/gsm8k_with_audio/test/45.wav new file mode 100644 index 0000000000000000000000000000000000000000..347afe1d657676014cc225f417667b6715b0ac1c --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/45.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a65adddc9e7cb0a684486badfd68a41610e2acd64e30b8e8a7da76607586092 +size 2112080 diff --git a/r1-a/dataset/gsm8k_with_audio/test/450.wav b/r1-a/dataset/gsm8k_with_audio/test/450.wav new file mode 100644 index 0000000000000000000000000000000000000000..5955c7a821f2802df3a0f01e99ee775f95a076eb --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/450.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e42276c81767cb468b27f095f71e5a2765272a753ee9d69cb399998561ee3348 +size 2104400 diff --git a/r1-a/dataset/gsm8k_with_audio/test/454.wav b/r1-a/dataset/gsm8k_with_audio/test/454.wav new file mode 100644 index 0000000000000000000000000000000000000000..816385f3138a5c97f8d61e3a01f918483fa55425 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/454.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c94820c5d8bdd41bdaf461ca950650950c0e44fa2427cb86c9ae59bfc8f781fc +size 599120 diff --git a/r1-a/dataset/gsm8k_with_audio/test/457.wav b/r1-a/dataset/gsm8k_with_audio/test/457.wav new file mode 100644 index 0000000000000000000000000000000000000000..9e239407f637dfc2e9c6bed82aa5d911acf6eaf2 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/457.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e3aaab17db00c2275f1ac1ebb8d2c710f4e7393488ca0a4fcdb45d57b6da2d9 +size 1774160 diff --git a/r1-a/dataset/gsm8k_with_audio/test/458.wav b/r1-a/dataset/gsm8k_with_audio/test/458.wav new file mode 100644 index 0000000000000000000000000000000000000000..e2f0d1edf384a556629633cca1110143127d04ed --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/458.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:813d8661f38ffba540960d4c50dd877714a196acd0d9fa1d2f9a9285ab35b40e +size 1082960 diff --git a/r1-a/dataset/gsm8k_with_audio/test/459.wav b/r1-a/dataset/gsm8k_with_audio/test/459.wav new file mode 100644 index 0000000000000000000000000000000000000000..95e0cda46856a37d65d0c2d4b0e10408db2b841b --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/459.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48f55b9dbcda81e8e98ac58c7e428de7ed472f79d4dbbf763d8f6232d53b3e00 +size 2945360 diff --git a/r1-a/dataset/gsm8k_with_audio/test/463.wav b/r1-a/dataset/gsm8k_with_audio/test/463.wav new file mode 100644 index 0000000000000000000000000000000000000000..381e779f8b84a99813e1f287ecb083004c1db8fc --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/463.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c0710e41b29894e6e2d4519a60c00ccdba41584f0a8fa1e9c4800bf69d5b63f +size 1298000 diff --git a/r1-a/dataset/gsm8k_with_audio/test/465.wav b/r1-a/dataset/gsm8k_with_audio/test/465.wav new file mode 100644 index 0000000000000000000000000000000000000000..061455fb8289431ba36b676d5760d5e009413c56 --- /dev/null +++ b/r1-a/dataset/gsm8k_with_audio/test/465.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccdf5bacc8e5c5c24f086471ad15811f56e355f2814963aaa2f2c94ce916da05 +size 1747280 diff --git a/r1-a/dataset/gsm8k_with_audio/test/515.wav b/r1-a/dataset/gsm8k_with_audio/test/515.wav new file mode 100644 index 
0000000000000000000000000000000000000000..19732bb65115a3ca38744a95b1b311efcbe38af8
--- /dev/null
+++ b/r1-a/dataset/gsm8k_with_audio/test/515.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48ea6967d2fa10e62b5e74b695650c9cf4f5a06d4a232937627e75f35b93e526
+size 1278800
diff --git a/r1-a/dataset/gsm8k_with_audio/test/877.wav b/r1-a/dataset/gsm8k_with_audio/test/877.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ee076fa3655df054b31ac2c01d07aa916d179521
Binary files /dev/null and b/r1-a/dataset/gsm8k_with_audio/test/877.wav differ
diff --git a/r1-a/dataset/gsm8k_with_audio/test/964.wav b/r1-a/dataset/gsm8k_with_audio/test/964.wav
new file mode 100644
index 0000000000000000000000000000000000000000..03829963e917db6f4e17683dcf398a5585e6803e
Binary files /dev/null and b/r1-a/dataset/gsm8k_with_audio/test/964.wav differ
diff --git a/r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json b/r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..096cc36d574f55059da068a7c17b14d3c11f1ccf
--- /dev/null
+++ b/r1-a/dataset/gsm8k_with_audio/test/final_dataset/dataset_info.json
@@ -0,0 +1,20 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "question_text": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "answer": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "audio_filepath": {
+      "dtype": "string",
+      "_type": "Value"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
\ No newline at end of file
diff --git a/r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json b/r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json
new file mode 100644
index 0000000000000000000000000000000000000000..c4c3c0219551b4ec3c8ff7c2b1b94d173db6496b
--- /dev/null
+++ b/r1-a/dataset/gsm8k_with_audio/test/final_dataset/state.json
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "a037f9a8bcdfc025",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
\ No newline at end of file
diff --git a/r1-a/dataset/gsm8k_with_audio/test/progress.txt b/r1-a/dataset/gsm8k_with_audio/test/progress.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3b6be60c9048a44c5133158efd05515a5c2a01ee
--- /dev/null
+++ b/r1-a/dataset/gsm8k_with_audio/test/progress.txt
@@ -0,0 +1 @@
+1319
\ No newline at end of file
diff --git a/r1-a/dataset/pkusafe.py b/r1-a/dataset/pkusafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bcf3fc6c1d7c2365fffa391eeb8c58abaa8bbca
--- /dev/null
+++ b/r1-a/dataset/pkusafe.py
@@ -0,0 +1,171 @@
+import json
+from collections import defaultdict
+from datasets import load_dataset, Dataset # make sure Dataset is imported
+from tqdm.auto import tqdm
+import traceback
+import os # needed to create the output directory
+
+# --- Configuration ---
+DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF"
+SPLIT_TO_PROCESS = "train" # or 'test', etc.
+OUTPUT_DATASET_DIR = "pku_saferlhf_filtered_unsafe_diverse_hf" # output directory name
+
+# --- Main script logic (same as the previous version) ---
+
+def get_true_harm_categories(harm_category_dict):
+    """Extract the keys (category names) whose value is True from a harm_category dict."""
+    if not isinstance(harm_category_dict, dict):
+        return []
+    return [category for category, is_present in harm_category_dict.items() if is_present]
+
+def filter_pku_saferlhf_detailed(dataset_name: str, split: str):
+    """
+    Load and filter the PKU-SafeRLHF dataset so that every kept sample has at
+    least one unsafe response, extract the names of the harm categories
+    involved, and cover as many of the originally present harm category names
+    as possible.
+    """
+    print(f"Loading dataset: {dataset_name}, split: {split}...")
print(f"加载数据集: {dataset_name}, split: {split}...") + try: + dataset = load_dataset(dataset_name, split=split) + print("数据集加载完成。") + except Exception as e: + print(f"错误:无法加载数据集 {dataset_name} (split: {split})。请检查名称和可用 split。") + print(f"错误详情: {e}") + return None, set() + + # --- 步骤 1: 识别原始数据集中所有出现过的 harm_category 名称 --- + print("正在识别所有原始 harm_category 名称...") + all_harm_category_names_in_split = set() + scan_errors = 0 + for example in tqdm(dataset, desc="扫描 harm_category"): + try: + for harm_dict_key in ["response_0_harm_category", "response_1_harm_category"]: + harm_dict = example.get(harm_dict_key) + if isinstance(harm_dict, dict): + true_categories = get_true_harm_categories(harm_dict) + all_harm_category_names_in_split.update(true_categories) + except Exception as e: + scan_errors += 1 + continue + + if scan_errors > 0: + print(f"警告:在扫描 harm_category 阶段,有 {scan_errors} 个样本处理出错。") + if not all_harm_category_names_in_split: + print("错误:未能从数据集中提取任何 harm_category 名称。请检查数据集结构或内容。") + return None, set() + print(f"原始数据集中找到 {len(all_harm_category_names_in_split)} 个唯一的 harm_category 名称。") + + # --- 步骤 2: 筛选至少有一个不安全回答的样本,并提取涉及的类别 --- + print("正在筛选包含不安全回答的样本并提取涉及的类别...") + filtered_data_list = [] # 改名为 list 以区分最终的 Dataset 对象 + covered_categories_in_filtered = set() + processed_count = 0 + filtered_count = 0 + filter_errors = 0 + + for example in tqdm(dataset, desc="筛选不安全样本"): + processed_count += 1 + try: + prompt = example.get("prompt") + response_0 = example.get("response_0") + response_1 = example.get("response_1") + is_safe_0 = example.get("is_response_0_safe") + is_safe_1 = example.get("is_response_1_safe") + response_0_harm_dict = example.get("response_0_harm_category") + response_1_harm_dict = example.get("response_1_harm_category") + + if not all([prompt, response_0, response_1, + is_safe_0 is not None, is_safe_1 is not None]): + continue + + if not is_safe_0 or not is_safe_1: + involved_harm_categories = set() + + if not is_safe_0: + categories_0 = get_true_harm_categories(response_0_harm_dict) + involved_harm_categories.update(categories_0) + + if not is_safe_1: + categories_1 = get_true_harm_categories(response_1_harm_dict) + involved_harm_categories.update(categories_1) + + # (可选过滤逻辑) + # if not involved_harm_categories: + # continue + + filtered_sample = { + "prompt": prompt, + "response_0": response_0, + "response_1": response_1, + "is_safe_0": is_safe_0, + "is_safe_1": is_safe_1, + # **** 注意:Dataset 对象对于 list of strings 的支持更好 **** + "involved_harm_categories": sorted(list(involved_harm_categories)), + "better_response_id": example.get("better_response_id"), + "safer_response_id": example.get("safer_response_id"), + # 可以根据需要添加其他字段 + } + filtered_data_list.append(filtered_sample) + covered_categories_in_filtered.update(involved_harm_categories) + filtered_count += 1 + + except Exception as e: + filter_errors += 1 + continue + + if filter_errors > 0: + print(f"警告:在筛选阶段,有 {filter_errors} 个样本处理出错。") + print(f"筛选完成。共处理 {processed_count} 个样本,筛选出 {filtered_count} 个符合条件的样本。") + + # --- 步骤 3: 检查 harm_category 覆盖情况 --- + missing_categories = all_harm_category_names_in_split - covered_categories_in_filtered + if missing_categories: + print(f"\n警告:以下 harm_category 名称存在于原始数据集中,但在筛选出的不安全样本中未能找到对应的类别: {missing_categories}") + else: + print("\n好消息!所有原始数据集中存在的 harm_category 名称都已在筛选后的数据中得到覆盖。") + print(f"最终数据集包含 {len(filtered_data_list)} 个样本,覆盖 {len(covered_categories_in_filtered)} 个 harm_category 名称。") + + return filtered_data_list, covered_categories_in_filtered # 返回 list 和 set + +# --- 主程序 --- +if 
__name__ == "__main__": + # 执行过滤 + filtered_list, final_categories = filter_pku_saferlhf_detailed(DATASET_NAME, SPLIT_TO_PROCESS) + + if filtered_list: # 检查返回的列表是否非空 + print(f"\n将 {len(filtered_list)} 条过滤后的数据保存为 Hugging Face Dataset 格式...") + print(f"目标目录: {OUTPUT_DATASET_DIR}") + + try: + # 1. 将 list of dicts 转换为 Dataset 对象 + # Hugging Face 会自动推断列类型。 involved_harm_categories 是 list of strings。 + hf_dataset = Dataset.from_list(filtered_list) + + # 2. 保存到磁盘 + if not os.path.exists(OUTPUT_DATASET_DIR): + os.makedirs(OUTPUT_DATASET_DIR) + print(f"已创建目录: {OUTPUT_DATASET_DIR}") + + hf_dataset.save_to_disk(OUTPUT_DATASET_DIR) + print(f"\n数据集成功保存到目录: {OUTPUT_DATASET_DIR}") + print(f"你可以使用以下代码加载它:") + print(f"from datasets import load_from_disk") + print(f"loaded_dataset = load_from_disk('{OUTPUT_DATASET_DIR}')") + print(f"\n最终数据集覆盖的 harm_category 名称: {final_categories}") + + # 打印一些样本看看 (从 Dataset 对象加载) + print("\n部分样本预览 (从保存的 Dataset 加载):") + loaded_dataset = Dataset.load_from_disk(OUTPUT_DATASET_DIR) # 加载回来验证 + for i in range(min(5, len(loaded_dataset))): + sample = loaded_dataset[i] + print(f"--- 样本 {i+1} ---") + print(f"Prompt: {sample['prompt'][:150]}...") + print(f"Response 0 (is_safe={sample['is_safe_0']}): {sample['response_0'][:100]}...") + print(f"Response 1 (is_safe={sample['is_safe_1']}): {sample['response_1'][:100]}...") + print(f"Involved Harm Categories: {sample['involved_harm_categories']}") + + except Exception as e: + print(f"\n错误:保存 Hugging Face Dataset 时出错: {e}") + traceback.print_exc() # 打印详细错误信息 + + else: + print("\n未能筛选出任何符合条件的样本,或在加载/处理数据时发生严重错误。未保存任何内容。") \ No newline at end of file diff --git a/r1-a/dataset/pkusafe_tts.py b/r1-a/dataset/pkusafe_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..31b6165b58615ef71c3576731a36946c285b08ba --- /dev/null +++ b/r1-a/dataset/pkusafe_tts.py @@ -0,0 +1,279 @@ +import os +import random +import torch +import torchaudio +# Import load_from_disk to load the dataset saved by your first script +from datasets import load_dataset, Dataset, load_from_disk +import sys +from tqdm import tqdm +import time # Import time for potential delays between retries + +sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct for your environment +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + +# ------------------------ +# 配置参数 +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" +# --- Path to the pre-filtered dataset saved by your FIRST script --- +FILTERED_DATASET_PATH = "pku_saferlhf_filtered_unsafe_diverse_hf" # <-- IMPORTANT: Make sure this matches the OUTPUT_DATASET_DIR from your first script +# --- Output path for THIS TTS script --- +OUTPUT_DATASET_PATH = './pku_saferlhf_filtered_with_audio' # <-- New output path for the dataset with audio +SAMPLE_RATE = 16000 +MAX_TTS_RETRIES = 3 # Maximum number of TTS attempts per query +RETRY_DELAY_SECONDS = 2 # Optional delay between retries + +# ------------------------ +# 辅助函数 (Identical to the previous version with retry logic) +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。 + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + sample = common_voice_dataset.select([idx])[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) + sr = audio['sampling_rate'] + if sr != sample_rate: + if waveform.dim() > 1: + waveform = 
+
+# ------------------------
+# Helper functions (identical to the previous version with retry logic)
+# ------------------------
+def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE):
+    """
+    Randomly draw one utterance and its transcript from the VoxPopuli dataset
+    (substituting for Common Voice) to use as the prompt.
+    """
+    idx = random.randint(0, len(common_voice_dataset) - 1)
+    sample = common_voice_dataset.select([idx])[0]
+    audio = sample['audio']
+    waveform = torch.tensor(audio['array'], dtype=torch.float32)
+    sr = audio['sampling_rate']
+    if sr != sample_rate:
+        if waveform.dim() > 1:
+            waveform = waveform.mean(dim=0)
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
+        waveform = resampler(waveform)
+    if waveform.dim() == 1:
+        waveform = waveform.unsqueeze(0)
+    if waveform.numel() == 0 or not sample['raw_text']:
+        print("Warning: Got an empty prompt, trying again...")
+        return get_random_prompt(common_voice_dataset, sample_rate)
+    return waveform, sample['raw_text']
+
+def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
+    """
+    Convert the input text to speech with the CosyVoice2 model, using a random
+    prompt for zero-shot inference. Includes retry logic on failure.
+    """
+    last_exception = None
+    prompt_text = "" # initialized so the error log below is safe even when prompt selection itself fails
+    for attempt in range(max_retries):
+        try:
+            prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
+
+            all_speech = []
+            inference_generator = cosyvoice.inference_zero_shot(
+                query_text,
+                prompt_text,
+                prompt_speech,
+                stream=stream,
+                text_frontend=False
+            )
+            for i, chunk in enumerate(inference_generator):
+                if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
+                    all_speech.append(chunk['tts_speech'])
+                else:
+                    print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
+
+            if not all_speech:
+                raise ValueError("TTS inference finished but produced no audio chunks.")
+
+            combined_speech = torch.cat(all_speech, dim=-1)
+            sample_rate_val = cosyvoice.sample_rate
+
+            return {
+                'audio_tensor': combined_speech,
+                'sample_rate': sample_rate_val
+            }
+        except Exception as e:
+            last_exception = e
+            print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
+            print(f"Text: '{query_text[:100]}...'")
+            print(f"Prompt Text: '{prompt_text[:100]}...'")
+            if attempt < max_retries - 1:
+                print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
+                time.sleep(RETRY_DELAY_SECONDS)
+            else:
+                print(f"All {max_retries} TTS attempts failed.")
+
+    print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
+    print(f"Last error: {last_exception}")
+    return None
+
+def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
+    """
+    Run TTS for a single sample of the filtered PKU-SafeRLHF dataset loaded
+    from disk. Processes example['prompt']. <--- Changed from 'query'/'question'
+    """
+    # --- Target the 'prompt' field from the filtered PKU-SafeRLHF dataset ---
+    query = example.get('prompt') # <--- Use 'prompt' field
+    if not query or not isinstance(query, str) or query.strip() == "":
+        print(f"Warning: Skipping example due to missing or empty 'prompt' field: {example.keys()}") # Log keys if prompt is missing
+        return None
+
+    # --- Use the text_to_audio function with retry logic ---
+    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
+
+    if audio_result is not None:
+        audio_tensor = audio_result['audio_tensor']
+        if audio_tensor.dim() == 1:
+            audio_tensor = audio_tensor.unsqueeze(0)
+        elif audio_tensor.dim() > 2:
+            print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. 
Attempting to flatten.") + audio_tensor = audio_tensor.view(1, -1) + + if audio_tensor.numel() == 0: + print(f"Warning: Generated audio tensor is empty for prompt: '{query[:60]}...'") + return None + + return { + 'audio_tensor': audio_tensor, + 'sample_rate': audio_result['sample_rate'] + } + else: + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset for prompts...") +try: + common_voice = load_dataset("facebook/voxpopuli", "en", split='train') + print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + if len(common_voice) == 0: + raise ValueError("VoxPopuli dataset loaded but contains no samples.") +except Exception as e: + print(f"Error loading VoxPopuli dataset: {e}") + sys.exit(1) + + +print("Initializing CosyVoice2 model...") +try: + cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # Verify this path is correct + load_jit=True, + load_trt=False, + fp16=False + ) +except Exception as e: + print(f"Error initializing CosyVoice2 model: {e}") + sys.exit(1) + +print(f"Loading pre-filtered PKU-SafeRLHF dataset from disk: {FILTERED_DATASET_PATH}") +try: + # --- Load the dataset saved by your first script --- + filtered_dataset = load_from_disk(FILTERED_DATASET_PATH) + if not filtered_dataset: + raise ValueError(f"Dataset loaded from '{FILTERED_DATASET_PATH}' is empty or invalid.") + print(f"Successfully loaded dataset with {len(filtered_dataset)} examples.") + # --- Assume the loaded dataset corresponds to a single split (e.g., 'train') --- + # Wrap it in a dictionary to match the structure expected by the loop below + dataset_dict = {"train": filtered_dataset} # Use "train" as the key, matching the split processed by the filter script + # Alternatively, if you know the split name used in the filter script was different, use that name. +except FileNotFoundError: + print(f"Error: Pre-filtered dataset not found at '{FILTERED_DATASET_PATH}'.") + print("Please ensure the first script ran successfully and saved the data to the correct location.") + sys.exit(1) +except Exception as e: + print(f"Error loading pre-filtered dataset from '{FILTERED_DATASET_PATH}': {e}") + sys.exit(1) + +# 创建输出目录 +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) + +# ------------------------ +# 主处理循环 +# ------------------------ +final_dataset_dict = {} # 存放各 split 最终处理后的数据 + +# Iterate through the splits defined in dataset_dict (should just be 'train' in this case) +for split_name, split_dataset in dataset_dict.items(): + print(f"Processing split: {split_name} with {len(split_dataset)} examples") + split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + + # 用于断点续跑的进度记录 + progress_file = os.path.join(split_output_dir, "progress.txt") + start_index = 0 + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + content = f.read().strip() + if content: + start_index = int(content) + print(f"Resuming split '{split_name}' from sample index {start_index}") + else: + print(f"Progress file '{progress_file}' is empty, starting from index 0.") + start_index = 0 + except ValueError: + print(f"Could not parse integer from progress file '{progress_file}'. Starting from index 0.") + start_index = 0 + except Exception as e: + print(f"Error reading progress file '{progress_file}': {e}. 
Starting from index 0.") + start_index = 0 + + final_samples = [] + + # 遍历处理每条样本 + pbar = tqdm(range(start_index, len(split_dataset)), desc=f"Processing {split_name}", initial=start_index, total=len(split_dataset)) + for i in pbar: + sample = split_dataset[i] + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + + if os.path.exists(output_wav_path): + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + with open(progress_file, "w") as f: + f.write(str(i + 1)) + continue + + # --- Perform TTS on the 'prompt' field --- + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + audio_tensor = result['audio_tensor'] + sample_rate_val = result['sample_rate'] + + try: + audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32) + if audio_tensor_save.dim() == 1: + audio_tensor_save = audio_tensor_save.unsqueeze(0) + elif audio_tensor_save.dim() > 2: + audio_tensor_save = audio_tensor_save.view(1, -1) + + torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val) + + # Preserve all original fields from the filtered dataset + add audio path + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + + except Exception as e: + print(f"Failed to save wav for sample {i} at {output_wav_path}: {e}") + else: + print(f"Sample {i} processing failed after retries (Prompt: '{sample.get('prompt', 'N/A')[:60]}...'), no audio generated.") + + # Update progress file after processing each sample + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # Generate Hugging Face Dataset from the collected successful samples and save + if final_samples: + final_dataset_obj = Dataset.from_list(final_samples) + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") + try: + print(f"Saving final dataset for split '{split_name}' to {final_dataset_save_path}...") + final_dataset_obj.save_to_disk(final_dataset_save_path) + print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with audio paths.") + final_dataset_dict[split_name] = final_dataset_obj + except Exception as e: + print(f"Error saving final dataset for split '{split_name}' to disk: {e}") + else: + print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.") + + +print("="*30) +if final_dataset_dict: + print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.") + print(f"Processed splits: {list(final_dataset_dict.keys())}") +else: + print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. 
Check logs for errors.")
+print("="*30)
\ No newline at end of file
diff --git a/r1-a/dataset/retry_rewrite.py b/r1-a/dataset/retry_rewrite.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d3fc72c4ce8fa1c4c8f8d99f56e6f332633292
--- /dev/null
+++ b/r1-a/dataset/retry_rewrite.py
@@ -0,0 +1,442 @@
+import os
+import http.client
+import json
+import time
+import random
+import socket  # required below: the retry handler catches socket.gaierror and socket.timeout
+# Import necessary types from datasets
+from datasets import load_dataset, Dataset, DatasetDict, Features, Value, Sequence
+from tqdm.auto import tqdm
+import sys
+import logging
+import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
+import shutil # Needed for atomic directory removal
+
+# --- Configuration ---
+DATASET_NAME = "virtuoussy/Multi-subject-RLVR" # Or the original source if needed
+DATASET_SPLIT = "train"
+API_HOST = "api2.aigcbest.top"
+API_PATH = "/v1/chat/completions"
+LLM_MODEL = "gpt-4.1-mini"
+# Read the key from the environment only; never commit a literal key to the repo.
+API_KEY = os.environ.get('AIGCBEST_API_KEY')
+if not API_KEY:
+    print("API Key is not set. Please set the AIGCBEST_API_KEY environment variable.")
+    sys.exit(1)
+
+OUTPUT_DIR = f"./{DATASET_NAME.split('/')[-1]}_rephrased"
+# Path to the existing, potentially incomplete, processed dataset (LOAD ONLY)
+PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed")
+# Path where intermediate and final results will be saved (SAVE ONLY)
+FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, f"{DATASET_SPLIT}_processed_final")
+
+MAX_WORKERS = 20 # Adjust based on your system and API rate limits
+REQUEST_DELAY_SECONDS = 0.15 # Base delay between requests
+MAX_RETRIES = 3 # Max retries for each API call
+SAVE_INTERVAL = 2000 # <<<--- How often to save progress (in number of processed items)
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.getLogger("datasets").setLevel(logging.WARNING)
+logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
+logging.getLogger("filelock").setLevel(logging.WARNING) # Quiet down filelock warnings during save
+
+# --- LLM API Function (call_llm_api) ---
+# (No changes needed here, keep the robust version)
+def call_llm_api(original_question, api_key, host, path, model, retries=MAX_RETRIES):
+    system_prompt = (
+        "You are an expert linguist specializing in converting structured prompts or "
+        "fill-in-the-blank problems into natural, spoken-language questions suitable for "
+        "text-to-speech (TTS). Your goal is to make the question sound like how a person "
+        "would naturally ask it. "
+        "If the input is a fill-in-the-blank problem (e.g., contains '-----'), "
+        "rephrase it as a direct question asking for the missing information. "
+        "Keep the core meaning, mathematical context, variables, and numbers exactly the same. "
+        "Focus only on rephrasing the *user's question* part provided. "
+        "Output *only* the rephrased question, without any introductory phrases like 'Here's the rephrased question:'."
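+        # A concrete illustration of the intended behaviour (hypothetical example,
+        # not drawn from the dataset):
+        #   input : "The largest planet in the solar system is -----."
+        #   output: "What is the largest planet in the solar system?"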
+ ) + payload = json.dumps({ + "model": model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": original_question} + ], + }) + headers = { + 'Accept': 'application/json', + 'Authorization': f'Bearer {api_key}', + 'User-Agent': 'HuggingFace Dataset Processing Script (Retry w/ Save)', + 'Content-Type': 'application/json' + } + time.sleep(random.uniform(REQUEST_DELAY_SECONDS * 0.8, REQUEST_DELAY_SECONDS * 1.2)) + + for attempt in range(retries): + # logging.debug(f"API call attempt {attempt + 1}/{retries} for: {original_question[:50]}...") + try: + conn = http.client.HTTPSConnection(host, timeout=60) # Increased timeout + conn.request("POST", path, payload, headers) + res = conn.getresponse() + status = res.status + data = res.read() + conn.close() + + if status == 200: + response_json = json.loads(data.decode("utf-8")) + if response_json.get("choices") and len(response_json["choices"]) > 0: + message = response_json["choices"][0].get("message") + if message and message.get("content"): + rephrased = message["content"].strip() + # Remove surrounding quotes more robustly + if len(rephrased) > 1: + if (rephrased.startswith('"') and rephrased.endswith('"')) or \ + (rephrased.startswith("'") and rephrased.endswith("'")): + rephrased = rephrased[1:-1] + # Handle cases like 'Rephrased: "..."' + if rephrased.lower().startswith(("rephrased:", "here's the rephrased question:")): + parts = rephrased.split(":", 1) + if len(parts) > 1: + potential_rephrased = parts[1].strip() + if (potential_rephrased.startswith('"') and potential_rephrased.endswith('"')) or \ + (potential_rephrased.startswith("'") and potential_rephrased.endswith("'")): + rephrased = potential_rephrased[1:-1] + else: + rephrased = potential_rephrased + + if rephrased and rephrased.strip().lower() != original_question.strip().lower(): + # logging.debug(f"Successfully rephrased: {rephrased[:80]}...") + return rephrased + elif not rephrased: + logging.warning(f"LLM returned empty/whitespace response for: {original_question[:50]}...") + return None + else: + logging.warning(f"LLM returned identical response for: {original_question[:50]}...") + return None # Treat identical as failure + logging.error(f"Unexpected API response structure: {data.decode('utf-8')}") + return None + elif status == 429: # Rate limit + retry_after_header = res.getheader('Retry-After', '5') + try: wait_time = int(retry_after_header) + except ValueError: wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5) + logging.warning(f"Rate limit exceeded (HTTP {status}). Retrying after {wait_time:.2f} seconds...") + time.sleep(wait_time) + elif status >= 500: # Server error + wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5) + logging.warning(f"Server error (HTTP {status}). Retrying after {wait_time:.2f} seconds...") + time.sleep(wait_time) + else: # Other client errors (4xx) - Don't retry these + logging.error(f"API Client Error: Status {status}, Response: {data.decode('utf-8')}") + return None + + except (http.client.HTTPException, ConnectionError, socket.gaierror, TimeoutError, socket.timeout) as e: + logging.error(f"Network/HTTP error during API call: {e}. 
Attempt {attempt + 1}/{retries}") + if attempt + 1 == retries: return None + wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3) + logging.warning(f"Waiting {wait_time:.2f} seconds before retry...") + time.sleep(wait_time) + except json.JSONDecodeError as e: + logging.error(f"Failed to decode API response: {e}. Response snippet: {data[:200] if data else 'N/A'}") + if attempt + 1 == retries: return None + wait_time = REQUEST_DELAY_SECONDS * (2 ** attempt) + random.uniform(1, 5) + time.sleep(wait_time) + except Exception as e: + logging.error(f"An unexpected error occurred during API call: {e}", exc_info=True) + if attempt + 1 == retries: return None + wait_time = REQUEST_DELAY_SECONDS * (1.5 ** attempt) + random.uniform(1, 3) + logging.warning(f"Waiting {wait_time:.2f} seconds before retry...") + time.sleep(wait_time) + + logging.error(f"API call failed after {retries} retries for: {original_question[:50]}...") + return None + +# --- Dataset Processing Function (rephrase_query_entry) --- +# (No changes needed here) +def rephrase_query_entry(example): + processed_example = example.copy() + original_query_list = example.get("query") + processed_example['query_rephrased_status'] = 'processing_retry' + + if original_query_list is None: + processed_example['query_rephrased_status'] = 'skipped_missing_query_column' + processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value + return processed_example + if not isinstance(original_query_list, list): + processed_example['query_rephrased_status'] = 'skipped_query_not_list' + processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value + return processed_example + if not original_query_list: + processed_example['query_rephrased_status'] = 'skipped_query_list_empty' + processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value + return processed_example + + user_question = None + for message in original_query_list: + if isinstance(message, dict) and message.get("role") == "user": + content = message.get("content") + if isinstance(content, str) and content.strip(): + user_question = content + break + else: + processed_example['query_rephrased_status'] = 'skipped_invalid_user_content' + processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value + return processed_example + if not user_question: + processed_example['query_rephrased_status'] = 'skipped_no_user_content_found' + processed_example['query_rephrased'] = example.get('query_rephrased') # Keep old value + return processed_example + + # logging.info(f"Retrying: {user_question[:60]}...") + rephrased_query_content = call_llm_api(user_question, API_KEY, API_HOST, API_PATH, LLM_MODEL) + + if rephrased_query_content: + processed_example["query_rephrased"] = rephrased_query_content + processed_example['query_rephrased_status'] = 'success' + else: + # Keep the OLD 'query_rephrased' value if LLM call fails this time + processed_example['query_rephrased'] = example.get('query_rephrased') + processed_example['query_rephrased_status'] = 'failed_llm_call' + + return processed_example + +# --- Function to Save Dataset Atomically --- +# Saves to a temporary path then renames for safety. 
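+# Editorial sketch of the same write-then-swap idea for a single file. The helper
+# below is hypothetical (it is not used by this script); os.replace() is an atomic
+# swap for files on POSIX, whereas the directory-level rmtree+rename in
+# save_dataset_atomically below briefly leaves the final path absent.
+def _demo_atomic_jsonl_write(records, final_path):
+    """Write records as JSON Lines to a temp file, then atomically swap into place."""
+    import tempfile
+    # Create the temp file in the same directory so the rename stays on one filesystem
+    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(final_path) or ".")
+    with os.fdopen(fd, "w", encoding="utf-8") as f:
+        for rec in records:
+            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
+    os.replace(tmp_path, final_path)  # readers see the old or the new file, never a half-written one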
+def save_dataset_atomically(data_list, output_path, features): + """Saves the list of data dictionaries atomically using the correct schema.""" + if not data_list: + logging.info("No data provided for saving.") + return False + + temp_output_path = output_path + "_saving" # Temporary directory + final_output_path = output_path + + logging.info(f"Attempting to save {len(data_list)} examples to temp path {temp_output_path}...") + try: + # Create dataset from the list of dictionaries using the defined features + processed_dataset = Dataset.from_list(list(data_list), features=features) # Convert just in case + + # Ensure parent directory exists + os.makedirs(os.path.dirname(final_output_path), exist_ok=True) + + # Remove any previous temporary directory if it exists + if os.path.exists(temp_output_path): + logging.warning(f"Removing existing temporary save directory: {temp_output_path}") + shutil.rmtree(temp_output_path) # Use shutil for directories + + # Save the dataset to the temporary path + processed_dataset.save_to_disk(temp_output_path) + logging.info(f"Successfully saved dataset to temporary path: {temp_output_path}") + + # --- Atomic Rename --- + # Remove the final destination path if it exists + if os.path.exists(final_output_path): + logging.debug(f"Removing existing final destination directory before rename: {final_output_path}") + shutil.rmtree(final_output_path) + + # Rename the temporary path to the final path + os.rename(temp_output_path, final_output_path) + logging.info(f"Successfully moved temporary save to final path: {final_output_path}") + return True + + except Exception as e: + logging.error(f"Failed during atomic save process to {final_output_path}: {e}", exc_info=True) + # Attempt to clean up temporary directory if it still exists after failure + if os.path.exists(temp_output_path): + try: + shutil.rmtree(temp_output_path) + logging.info(f"Cleaned up temporary directory {temp_output_path} after error.") + except Exception as cleanup_e: + logging.error(f"Could not clean up temporary directory {temp_output_path} after error: {cleanup_e}") + # Fallback save attempt to JSON Lines (unchanged) + fallback_json_path = final_output_path + ".jsonl.failed_save" # Indicate it's a fallback + logging.warning(f"Attempting fallback save to JSON Lines file: {fallback_json_path}") + try: + with open(fallback_json_path, 'w', encoding='utf-8') as f: + for item in data_list: + f.write(json.dumps(item, ensure_ascii=False) + '\n') + logging.info(f"Successfully saved fallback JSON Lines file.") + except Exception as json_e: + logging.error(f"Fallback JSON save also failed: {json_e}", exc_info=True) + return False + +# --- Function to Check if Retry is Needed --- +# (No changes needed here) +def needs_retry(example): + rephrased = example.get('query_rephrased') + status = example.get('query_rephrased_status') + # Retry if rephrased is missing OR status is anything other than 'success' + # This ensures failed/skipped items are retried. 
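+    # Illustrative truth table for this predicate (editorial note):
+    #   query_rephrased=None,  status='success'         -> True  (retry)
+    #   query_rephrased='...', status='success'         -> False (done)
+    #   query_rephrased='...', status='failed_llm_call' -> True  (retry)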
+ retry_flag = (rephrased is None) or (status != 'success') + return retry_flag + +# --- Main Execution --- +if __name__ == "__main__": + start_time = time.time() + logging.info("======================================================") + logging.info(f" Starting Dataset Processing - RETRY w/ PERIODIC SAVE") + logging.info(f" Saving progress every {SAVE_INTERVAL} processed items.") + logging.info("======================================================") + logging.info(f"Loading existing data from: {PROCESSED_DATA_PATH}") + logging.info(f"Intermediate and final output will be saved to: {FINAL_OUTPUT_PATH}") + + # --- Load Existing Processed Dataset --- + if not os.path.exists(PROCESSED_DATA_PATH): + logging.error(f"Existing data directory not found at '{PROCESSED_DATA_PATH}'. Cannot run retry mode.") + sys.exit(1) + + logging.info(f"Loading existing dataset from {PROCESSED_DATA_PATH}...") + try: + existing_dataset = Dataset.load_from_disk(PROCESSED_DATA_PATH) + # Get features *before* converting to list + dataset_features = existing_dataset.features + logging.info(f"Dataset features detected: {dataset_features}") + # Convert to a list of dictionaries for in-memory modification + results_list = existing_dataset.to_list() + total_examples = len(results_list) + logging.info(f"Loaded {total_examples} examples.") + except Exception as e: + logging.error(f"Failed to load dataset from {PROCESSED_DATA_PATH}: {e}", exc_info=True) + # Check if the final path exists from a previous run - maybe load that? + # For now, exiting is safer to avoid inconsistent states. + # if os.path.exists(FINAL_OUTPUT_PATH): + # logging.warning(f"Consider manually checking/using the existing data at {FINAL_OUTPUT_PATH}") + sys.exit(1) + + # --- Identify Indices to Retry --- + logging.info("Identifying examples needing retry...") + indices_to_retry = [ + i for i, example in enumerate(tqdm(results_list, desc="Checking examples")) if needs_retry(example) + ] + num_to_retry = len(indices_to_retry) + + if num_to_retry == 0: + logging.info("No examples found needing retry based on the criteria ('query_rephrased' is None or status != 'success').") + logging.info(f"Saving the existing dataset to the final location '{FINAL_OUTPUT_PATH}' as is...") + if not save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features): # Use atomic save + logging.error("Failed to save the dataset to the final location even though no retries were needed.") + sys.exit(0) + + logging.info(f"Identified {num_to_retry} examples to retry out of {total_examples}.") + + # --- Prepare for Concurrent Retries --- + processed_count_total = 0 # Total processed in this run + processed_since_last_save = 0 # Counter for periodic saving + last_save_time = time.time() # Track time for saving message + + logging.info("Starting concurrent retries with periodic saving...") + + # --- ThreadPoolExecutor for Concurrency --- + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + # Submit tasks only for the identified indices + futures = { + executor.submit(rephrase_query_entry, results_list[i]): i + for i in indices_to_retry + } + + try: + # Initialize progress bar for retries + pbar = tqdm(total=num_to_retry, desc="Retrying examples", unit="example") + # Process futures as they complete + for future in concurrent.futures.as_completed(futures): + original_index = futures[future] # Get the original list index + try: + # Get the result (the updated dictionary) + updated_example_dict = future.result() + # --- IMMEDIATE UPDATE of the main list --- + 
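+                    # Editorial note: because `futures` maps each future back to its
+                    # source index, the result can be written straight into
+                    # results_list; every periodic save below then snapshots all
+                    # retries completed so far, so a crash or Ctrl+C loses at most
+                    # ~SAVE_INTERVAL items of work.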
results_list[original_index] = updated_example_dict + pbar.set_postfix({"LastStatus": updated_example_dict.get('query_rephrased_status', 'N/A')}, refresh=True) + + except Exception as exc: + # Catch potential exceptions *from* the rephrase_query_entry function + logging.error(f'Retry task for index {original_index} encountered an exception: {exc}', exc_info=True) + # Create an error placeholder and update the main list + error_placeholder = results_list[original_index].copy() # Start with original data + error_placeholder['query_rephrased_status'] = f'failed_retry_exception_{type(exc).__name__}' + # Keep the old query_rephrased value + results_list[original_index] = error_placeholder + pbar.set_postfix({"LastStatus": error_placeholder['query_rephrased_status']}, refresh=True) + + finally: + # Increment counters and update progress bar + processed_count_total += 1 + processed_since_last_save += 1 + pbar.update(1) + + # --- Periodic Save Check --- + if processed_since_last_save >= SAVE_INTERVAL: + current_time = time.time() + time_since_last = current_time - last_save_time + logging.info(f"\n--- Processed {processed_since_last_save} items (Total: {processed_count_total}/{num_to_retry}). Time since last save: {time_since_last:.1f}s. Saving progress... ---") + if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features): + logging.info(f"--- Progress successfully saved to {FINAL_OUTPUT_PATH} ---") + processed_since_last_save = 0 # Reset counter + last_save_time = current_time + else: + logging.error(f"--- FAILED TO SAVE PROGRESS! Check errors above. Will retry saving later. ---") + # Don't reset the counter, maybe the next save will work + + except KeyboardInterrupt: + logging.warning("\nCtrl+C detected! Attempting final save...") + # Let the finally block handle the save + + except Exception as e: + logging.error(f"An unexpected error occurred during the main retry loop: {e}", exc_info=True) + logging.error("Attempting final save...") + # Let the finally block handle the save + + finally: + # --- This block executes after the loop finishes, OR if an exception/interrupt occurs --- + if 'pbar' in locals() and pbar is not None: + pbar.close() + + logging.info("--- Processing loop finished or interrupted. ---") + + # --- Final Save Attempt --- + # No need to update results_list again, it was updated incrementally. + logging.info(f"Attempting final save of the dataset ({len(results_list)} items) to: {FINAL_OUTPUT_PATH}") + if save_dataset_atomically(results_list, FINAL_OUTPUT_PATH, dataset_features): + logging.info("--- Final dataset state saved successfully. ---") + else: + logging.error(">>> FINAL SAVE FAILED! <<< Check logs. 
Fallback JSON file might exist.") + + # --- Final Verification (Optional but Recommended) --- + logging.info("------------------------------------------------------") + logging.info("Verification: Attempting to load final saved dataset...") + try: + final_reloaded_dataset = Dataset.load_from_disk(FINAL_OUTPUT_PATH) + logging.info(f"Successfully reloaded final dataset with {len(final_reloaded_dataset)} examples from {FINAL_OUTPUT_PATH}.") + + # Simple status count + status_counts = {} + none_rephrased_count = 0 + for ex in final_reloaded_dataset: + status = ex.get('query_rephrased_status', 'unknown_status') + status_counts[status] = status_counts.get(status, 0) + 1 + if ex.get('query_rephrased') is None or not str(ex.get('query_rephrased')).strip(): + none_rephrased_count += 1 + + logging.info("Final status counts:") + for status, count in sorted(status_counts.items()): + logging.info(f" - {status}: {count}") + + final_success = status_counts.get('success', 0) + final_failed = sum(count for st, count in status_counts.items() if st and (st.startswith('failed_') or st == 'processing_retry')) # Items potentially stuck + final_skipped = sum(count for st, count in status_counts.items() if st and st.startswith('skipped_')) + other_count = len(final_reloaded_dataset) - final_success - final_failed - final_skipped + + logging.info(f"Summary: Success={final_success}, Failed/Incomplete={final_failed}, Skipped={final_skipped}, Other={other_count}") + if none_rephrased_count > 0: + logging.warning(f"WARNING: {none_rephrased_count} items have None/empty 'query_rephrased' in the final dataset.") + if final_failed > 0: + logging.warning(f"WARNING: {final_failed} items did not reach 'success' or 'skipped' status.") + + + except FileNotFoundError: + logging.error(f"Verification failed: Final dataset directory not found at {FINAL_OUTPUT_PATH}. 
Final save likely failed.") + except Exception as e: + logging.error(f"Verification failed: Could not reload/verify final dataset from {FINAL_OUTPUT_PATH}: {e}", exc_info=True) + + # --- Script End --- + end_time = time.time() + logging.info("------------------------------------------------------") + logging.info(f"Script finished in {end_time - start_time:.2f} seconds.") + logging.info("======================================================") \ No newline at end of file diff --git a/r1-a/dataset/retts.py b/r1-a/dataset/retts.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4bdbfa487e290010f48bc00ac4623382a6b0dc --- /dev/null +++ b/r1-a/dataset/retts.py @@ -0,0 +1,559 @@ +# -*- coding: utf-8 -*- +import os +import argparse +import torch +import re +import jiwer +from datasets import load_from_disk, concatenate_datasets, Dataset, Features, Value, Audio # Keep Audio for potential output type hint if needed +from transformers import pipeline +import logging +import time +import soundfile as sf # For checking validity +import librosa # For loading audio in batch function +import numpy as np +import collections +import pyarrow as pa + +# --- 配置日志 --- +log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - [Shard %(shard_index)s] - %(message)s') +logger = logging.getLogger() +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.INFO) +ch.setFormatter(log_formatter) +logger.addHandler(ch) +fh = None # File handler setup in main + +# --- 常量与参数定义 --- +MODEL_ID = "openai/whisper-large-v3" +DATASET_PATH = "/home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative" # ADJUST IF NEEDED +OUTPUT_DIR = "/home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative_processed_shards" # ADJUST IF NEEDED +LOG_DIR = os.path.join(OUTPUT_DIR, "logs") +NUM_SHARDS = 50 +MIN_AUDIO_DURATION_MS = 100 +TARGET_SR = 16000 # Whisper expected sample rate + +# --- 文本规范化函数 --- +def normalize_text(text): + if text is None: + return "" + text = str(text).lower() + text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、]', '', text) + text = re.sub(r'\s+', ' ', text).strip() + return text + +# --- 音频文件预检查函数 (检查路径) --- +# (This function remains largely the same as it already worked with paths) +def check_audio_file_validity(example, shard_idx_for_log=None): + is_valid = False + error_msg = "Unknown error" + duration_ms = 0 + audio_path = example.get("question_audio") # Directly get the path string + log_prefix = f"[Shard {shard_idx_for_log}] " if shard_idx_for_log is not None else "" + + if audio_path and isinstance(audio_path, str): + if os.path.exists(audio_path): + try: + info = sf.info(audio_path) + duration_ms = int(info.duration * 1000) + if info.samplerate > 0 and info.frames > 0: + if duration_ms >= MIN_AUDIO_DURATION_MS: + is_valid = True + error_msg = None + else: + error_msg = f"Audio duration {duration_ms}ms < minimum {MIN_AUDIO_DURATION_MS}ms" + else: + error_msg = "Invalid audio properties (samplerate/frames <= 0)" + except Exception as e: + logger.warning(f"{log_prefix}Cannot read info/validate file {audio_path}: {type(e).__name__}") + error_msg = f"Cannot read/validate audio: {type(e).__name__}" + else: + error_msg = "Audio file not found" + elif audio_path is None: + error_msg = "Audio path is missing or null" + else: + error_msg = f"Audio path is not a string (type: {type(audio_path).__name__})" + + return { + "audio_is_valid": is_valid, + "audio_check_error": error_msg, + "audio_duration_ms": duration_ms + # Don't add the original path back here, it's 
already in the dataset + } + + +# --- 核心处理函数 (批处理 - 加载音频路径) --- +def check_audio_quality_batch(batch, asr_pipeline, wer_threshold, target_sr, shard_idx_for_log=None): + log_prefix = f"[Shard {shard_idx_for_log}] " if shard_idx_for_log is not None else "" + results = {"asr_transcription": [], "wer": [], "is_bad_tts": [], "error_message": []} + original_texts = batch.get("question_text", []) + audio_paths = batch.get("question_audio", []) # Get list of paths + + num_samples_in_batch = len(audio_paths) + if not audio_paths or not original_texts or len(audio_paths) != len(original_texts): + logger.warning(f"{log_prefix}Batch inconsistency or empty data. Paths: {len(audio_paths)}, Text: {len(original_texts)}") + num_samples = max(len(audio_paths), len(original_texts)) + results["asr_transcription"] = [""] * num_samples + results["wer"] = [1.0] * num_samples + results["is_bad_tts"] = [True] * num_samples + results["error_message"] = ["Inconsistent batch data or missing paths/text"] * num_samples + return results + + batch_load_start_time = time.time() + loaded_audios = [] + load_errors = [None] * num_samples_in_batch # Track loading errors per sample + + # --- 加载批次中的所有音频 --- + for i, path in enumerate(audio_paths): + try: + if not path or not isinstance(path, str): + raise ValueError("Invalid audio path") + # Load using librosa, force mono, resample to target_sr + audio_array, sample_rate = librosa.load(path, sr=target_sr, mono=True) + loaded_audios.append(audio_array) + except Exception as e: + logger.warning(f"{log_prefix}Failed to load audio file '{path}': {type(e).__name__}. Skipping for ASR.") + loaded_audios.append(None) # Use None as placeholder for failed loads + load_errors[i] = f"Audio load failed: {type(e).__name__}" + + batch_load_end_time = time.time() + logger.debug(f"{log_prefix}Loaded {len([a for a in loaded_audios if a is not None])}/{num_samples_in_batch} audios in {batch_load_end_time - batch_load_start_time:.2f} sec.") + + # Filter out None placeholders before sending to ASR pipeline? + # Option 1: Send only valid audios (might complicate matching results back) + # Option 2: Send list including None/empty arrays, let pipeline handle (or pre-handle) + # Let's try Option 2 with pre-handling: Replace None with empty array for pipeline input + pipeline_inputs = [] + valid_indices = [] # Track indices of samples sent to pipeline + for i, audio_data in enumerate(loaded_audios): + if audio_data is not None and len(audio_data) > 0: # Check if loading succeeded and audio not empty + pipeline_inputs.append(audio_data) + valid_indices.append(i) + # else: keep load_errors[i] message + + asr_results_list = [None] * num_samples_in_batch # Initialize results list matching original batch size + + # --- ASR 推理 (仅对成功加载的音频) --- + if pipeline_inputs: # Only run pipeline if there are valid audios + batch_asr_start_time = time.time() + try: + # Pass the list of NumPy arrays directly to the pipeline + asr_outputs = asr_pipeline(pipeline_inputs, generate_kwargs={"language": "zh", "task": "transcribe"}) + + if not isinstance(asr_outputs, list): + asr_outputs = [asr_outputs] # Ensure it's a list + + # Map results back to original batch positions using valid_indices + if len(asr_outputs) == len(valid_indices): + for idx, result in zip(valid_indices, asr_outputs): + asr_results_list[idx] = result # Place result at the correct original index + else: + logger.error(f"{log_prefix}ASR output count ({len(asr_outputs)}) mismatch with valid input count ({len(valid_indices)}). 
Marking all in batch as error.") + # Mark all samples in the batch with an error if counts mismatch + for i in range(num_samples_in_batch): + if load_errors[i] is None: # If loading didn't fail, mark as ASR mismatch + load_errors[i] = "ASR count mismatch error" + + # --- Error Handling for ASR Pipeline --- + # Catch errors specifically from the pipeline call + except ValueError as ve: # e.g., internal batching errors if any remain + logger.error(f"{log_prefix}ValueError during ASR pipeline processing: {ve}", exc_info=True) + for idx in valid_indices: # Mark only those sent to pipeline as failed + asr_results_list[idx] = "ERROR: ASR ValueError" # Placeholder or error indicator + if load_errors[idx] is None: load_errors[idx] = f"ASR ValueError: {str(ve)[:100]}" + except torch.cuda.OutOfMemoryError: + logger.error(f"{log_prefix}CUDA OutOfMemoryError during ASR batch processing.") + torch.cuda.empty_cache() + for idx in valid_indices: + asr_results_list[idx] = "ERROR: ASR OOM" + if load_errors[idx] is None: load_errors[idx] = "ASR CUDA OOM" + except Exception as e: + logger.error(f"{log_prefix}Exception during ASR pipeline processing: {e}", exc_info=True) + for idx in valid_indices: + asr_results_list[idx] = "ERROR: ASR Exception" + if load_errors[idx] is None: load_errors[idx] = f"ASR Exception: {str(e)[:100]}" + + batch_asr_end_time = time.time() + logger.debug(f"{log_prefix}ASR processed {len(valid_indices)} audios in {batch_asr_end_time - batch_asr_start_time:.2f} sec.") + + # --- 计算 WER (遍历原始批次大小) --- + for i in range(num_samples_in_batch): + transcription = "" + wer = 1.0 # Default to max error + is_bad = True + error_msg = load_errors[i] # Start with potential loading error + + asr_result = asr_results_list[i] + + if error_msg is None: # If no loading error, proceed with ASR result + if isinstance(asr_result, dict) and "text" in asr_result: + transcription = asr_result["text"] + original_text = original_texts[i] + + norm_original = normalize_text(original_text) + norm_transcription = normalize_text(transcription) + + if not norm_original: + wer = 1.0 if norm_transcription else 0.0 + is_bad = True if norm_transcription else False + error_msg = "Original text normalized to empty" if is_bad else "Original text normalized to empty, transcription also empty" + else: + try: + wer = jiwer.wer(norm_original, norm_transcription) + wer = min(wer, 1.0) # Clamp WER + is_bad = wer > wer_threshold + except ValueError as e: + wer = 1.0 + is_bad = True + logger.warning(f"{log_prefix}Jiwer WER calculation error for idx {i}. Setting WER to 1.0. 
Error: {e}") + error_msg = f"WER calculation error: {e}" + except Exception as e: + wer = 1.0 + is_bad = True + logger.error(f"{log_prefix}Unexpected error during WER calculation idx {i}: {e}", exc_info=True) + error_msg = f"Unexpected WER error: {e}" + elif isinstance(asr_result, str) and "ERROR:" in asr_result: + # Handle error strings passed from ASR exception handling + error_msg = asr_result + wer = 1.0 + is_bad = True + else: + # ASR didn't run (load failed) or returned unexpected format + # error_msg should already be set from load_errors + # If error_msg is somehow still None, set a generic one + if error_msg is None: + error_msg = "ASR did not produce valid output" + wer = 1.0 + is_bad = True + + results["asr_transcription"].append(transcription) + results["wer"].append(wer) + results["is_bad_tts"].append(is_bad) + results["error_message"].append(error_msg) + + return results + + +# --- 统计信息记录函数 --- +# (This function remains the same, as it operates on the processed shard data) +def log_shard_statistics(processed_shard, shard_index, wer_threshold, processing_time): + log_prefix = f"[Shard {shard_index}] " + logger.info(f"{log_prefix}--- Shard {shard_index} Statistics ---") + # ... (rest of the function is identical to the previous version) ... + total_samples = len(processed_shard) + logger.info(f"{log_prefix}Total samples processed in this shard: {total_samples}") + if total_samples == 0: + logger.info(f"{log_prefix}Shard was empty, no statistics to report.") + logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---") + return + + logger.info(f"{log_prefix}Processing time for this shard: {processing_time:.2f} seconds") + if processing_time > 0: + logger.info(f"{log_prefix}Overall processing speed: {total_samples / processing_time:.2f} samples/sec") + logger.info(f"{log_prefix}WER threshold used: {wer_threshold}") + + required_cols = ['is_bad_tts', 'wer', 'error_message', 'question_text', 'asr_transcription'] + if not all(col in processed_shard.column_names for col in required_cols): + logger.error(f"{log_prefix}Processed shard is missing required columns for statistics ({required_cols}). 
Skipping detailed stats.") + logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---") + return + + try: + bad_tts_count = sum(processed_shard['is_bad_tts']) + bad_tts_percentage = (bad_tts_count / total_samples) * 100 if total_samples > 0 else 0 + logger.info(f"{log_prefix}Bad TTS samples (WER > {wer_threshold} or Error): {bad_tts_count} ({bad_tts_percentage:.2f}%)") + logger.info(f"{log_prefix}Good TTS samples (WER <= {wer_threshold}): {total_samples - bad_tts_count} ({100 - bad_tts_percentage:.2f}%)") + + wer_scores = [w for w in processed_shard['wer'] if w is not None and not np.isnan(w)] + if wer_scores: + logger.info(f"{log_prefix}WER Score Distribution (for samples where WER could be calculated):") + logger.info(f"{log_prefix} Count: {len(wer_scores)}") + logger.info(f"{log_prefix} Min: {np.min(wer_scores):.4f}") + logger.info(f"{log_prefix} Max: {np.max(wer_scores):.4f}") # Should be <= 1.0 now + logger.info(f"{log_prefix} Mean: {np.mean(wer_scores):.4f}") + logger.info(f"{log_prefix} Median: {np.median(wer_scores):.4f}") + q25, q75 = np.percentile(wer_scores, [25, 75]) + logger.info(f"{log_prefix} 25th Percentile: {q25:.4f}") + logger.info(f"{log_prefix} 75th Percentile: {q75:.4f}") + else: + logger.info(f"{log_prefix}WER Score Distribution: No valid WER scores found.") + + error_messages = [msg for msg in processed_shard['error_message'] if msg] + if error_messages: + error_counts = collections.Counter(error_messages) + logger.info(f"{log_prefix}Error Message Summary (Top 10):") + for msg, count in error_counts.most_common(10): + logger.info(f"{log_prefix} - \"{msg}\": {count} occurrences") + if len(error_counts) > 10: + logger.info(f"{log_prefix} ... ({len(error_counts) - 10} more error types)") + else: + logger.info(f"{log_prefix}Error Message Summary: No processing errors recorded.") + + logger.info(f"\n{log_prefix}--- Example Good TTS Samples (WER <= {wer_threshold}) ---") + # Use select for potentially large datasets, disable caching for filter + good_samples_indices = [i for i, bad in enumerate(processed_shard['is_bad_tts']) if not bad] + num_good_to_show = min(5, len(good_samples_indices)) + if num_good_to_show > 0: + # Select the samples using indices; this is faster than filter for small selects + good_samples_view = processed_shard.select(good_samples_indices[:num_good_to_show]) + for i in range(num_good_to_show): + sample = good_samples_view[i] + logger.info(f"{log_prefix} Example {i+1}:") + logger.info(f"{log_prefix} Original Text: {sample['question_text']}") + logger.info(f"{log_prefix} ASR Transcript: {sample['asr_transcription']}") + logger.info(f"{log_prefix} WER: {sample['wer']:.4f}") + logger.info(f"{log_prefix} Audio Path: {sample['question_audio']}") # Show path + else: + logger.info(f"{log_prefix} No good samples found in this shard.") + + + logger.info(f"\n{log_prefix}--- Example Bad TTS Samples (WER > {wer_threshold} or Error) ---") + bad_samples_indices = [i for i, bad in enumerate(processed_shard['is_bad_tts']) if bad] + num_bad_to_show = min(5, len(bad_samples_indices)) + if num_bad_to_show > 0: + bad_samples_view = processed_shard.select(bad_samples_indices[:num_bad_to_show]) + for i in range(num_bad_to_show): + sample = bad_samples_view[i] + logger.info(f"{log_prefix} Example {i+1}:") + logger.info(f"{log_prefix} Original Text: {sample['question_text']}") + logger.info(f"{log_prefix} ASR Transcript: {sample['asr_transcription']}") + logger.info(f"{log_prefix} WER: {sample['wer']:.4f}") + logger.info(f"{log_prefix} Error Msg: 
{sample['error_message']}")
+                logger.info(f"{log_prefix}    Audio Path: {sample['question_audio']}") # Show path
+        else:
+            logger.info(f"{log_prefix}  No bad samples found in this shard.")
+
+    except Exception as e:
+        logger.error(f"{log_prefix}Error generating statistics: {e}", exc_info=True)
+
+    logger.info(f"{log_prefix}--- End Shard {shard_index} Statistics ---")
+
+
+# --- Main function ---
+def main():
+    global fh
+    parser = argparse.ArgumentParser(description="Process a shard of the dataset using Whisper ASR, loading audio from paths.")
+    # ... (Argument parsing remains the same) ...
+    parser.add_argument("--shard_index", type=int, required=True, help=f"Index of the shard to process (0 to {NUM_SHARDS-1}).")
+    parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID to use for this process.")
+    parser.add_argument("--wer_threshold", type=float, default=0.4, help="WER threshold to mark TTS as bad.")
+    parser.add_argument("--pipeline_batch_size", type=int, default=8, help="Internal batch size for the ASR pipeline.")
+    parser.add_argument("--map_batch_size", type=int, default=16, help="Batch size for the datasets.map function (how many rows are passed to the batch function).")
+    parser.add_argument("--num_check_workers", type=int, default=4, help="Number of workers for the audio pre-check map.")
+    args = parser.parse_args()
+    shard_index = args.shard_index
+    # ... (Rest of argument setup, logging setup, GPU setup - same as before) ...
+    gpu_id = args.gpu_id
+    wer_threshold = args.wer_threshold
+    pipeline_batch_size = args.pipeline_batch_size
+    map_batch_size = args.map_batch_size
+    num_check_workers = args.num_check_workers
+
+    os.makedirs(LOG_DIR, exist_ok=True)
+    log_file = os.path.join(LOG_DIR, f"shard_{shard_index}_gpu_{gpu_id}.log")
+    fh = logging.FileHandler(log_file, mode='w')
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(log_formatter)
+    logger.addHandler(fh)
+
+    old_factory = logging.getLogRecordFactory()
+    def record_factory(*args, **kwargs):
+        record = old_factory(*args, **kwargs)
+        record.shard_index = shard_index
+        return record
+    logging.setLogRecordFactory(record_factory)
+
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+    device = "cuda:0"  # after CUDA_VISIBLE_DEVICES, the selected GPU is always cuda:0
+    logger.info(f"Process started for Shard {shard_index} on GPU {gpu_id} (logical device {device})")
+    logger.info(f"Arguments: {args}")
+
+    processed_shard = None
+    processing_time = 0
+
+    try:
+        # --- Load the full dataset ---
+        logger.info(f"Loading dataset from {DATASET_PATH}")
+        try:
+            full_ds = load_from_disk(DATASET_PATH)
+            logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
+            # Check the feature type of question_audio - SHOULD BE string
+            if 'question_audio' not in full_ds.features:
+                logger.error("Dataset loaded, but required 'question_audio' column is missing!")
+                return
+            logger.info(f"Feature 'question_audio': {full_ds.features['question_audio']}")
+            if not isinstance(full_ds.features['question_audio'], Value) or full_ds.features['question_audio'].dtype != 'string':
+                logger.warning(f"'question_audio' column type is not string ({full_ds.features['question_audio']}). 
Attempting to proceed, but expecting paths.") + + except Exception as e: + logger.error(f"Failed to load dataset: {e}", exc_info=True) + return + + # --- 数据预处理:检查音频文件有效性 (on paths) --- + logger.info(f"Checking audio file validity (min duration: {MIN_AUDIO_DURATION_MS}ms)...") + check_features = Features({ + **full_ds.features, + 'audio_is_valid': Value('bool'), + 'audio_check_error': Value('string'), + 'audio_duration_ms': Value('int64') + }) + num_check_workers = max(1, min(num_check_workers, os.cpu_count())) + logger.info(f"Using {num_check_workers} workers for audio check.") + full_ds_checked = full_ds.map( + check_audio_file_validity, + num_proc=num_check_workers, + features=check_features, + batched=False, + fn_kwargs={"shard_idx_for_log": shard_index} + ) + logger.info("Audio validity check complete.") + + # --- 过滤掉无效音频 --- + original_count = len(full_ds_checked) + valid_audio_ds = full_ds_checked.filter( + lambda x: x['audio_is_valid'], + num_proc=num_check_workers, + load_from_cache_file=False + ) + filtered_count = original_count - len(valid_audio_ds) + logger.info(f"Filtered out {filtered_count} samples based on path validity/duration. Kept {len(valid_audio_ds)} samples.") + + # Log filtering reasons (same as before) + if filtered_count > 0: + # Avoid running another potentially slow filter just for logging + logger.warning("Logging top filtering reasons (based on initial check results, sample limit applies if dataset large)...") + try: + error_reasons = collections.Counter(r['audio_check_error'] for r in full_ds_checked.filter(lambda x: not x['audio_is_valid'], load_from_cache_file=False).select(range(min(1000, filtered_count)))) + for reason, count in error_reasons.most_common(10): + if reason: # Don't log None reasons if any slip through + logger.warning(f" - {reason}: {count} samples") + except Exception as log_e: + logger.warning(f"Could not retrieve filtering reasons: {log_e}") + + + if len(valid_audio_ds) == 0: + logger.error("No valid audio samples found after filtering. Exiting.") + return + + # --- !! REMOVED cast_column step !! --- + # The 'question_audio' column remains as paths in valid_audio_ds + + # --- 获取当前进程需要处理的分片 --- + logger.info(f"Creating shard {shard_index} from valid audio data (paths)...") + ds_shard = valid_audio_ds.shard(num_shards=NUM_SHARDS, index=shard_index, contiguous=True) + logger.info(f"Shard {shard_index} created with {ds_shard.num_rows} rows.") + # Log features to confirm 'question_audio' is still string + logger.info(f"Shard features: {ds_shard.features}") + + + if ds_shard.num_rows == 0: + logger.warning(f"Shard {shard_index} is empty after sharding. 
Saving empty structure and exiting process.") + # Define empty output features (keeping original path column) + final_features = Features({ + **ds_shard.features, # Includes original columns like question_audio (path) + 'asr_transcription': Value('string'), + 'wer': Value('float32'), + 'is_bad_tts': Value('bool'), + 'error_message': Value('string') + }) + # Remove check columns from features before creating empty table + final_features.pop('audio_is_valid', None) + final_features.pop('audio_check_error', None) + final_features.pop('audio_duration_ms', None) + + shard_output_path = os.path.join(OUTPUT_DIR, f"shard_{shard_index}") + os.makedirs(shard_output_path, exist_ok=True) + try: + empty_table = pa.Table.from_pydict({}, schema=final_features.arrow_schema) + empty_ds = Dataset(arrow_table=empty_table) + empty_ds.save_to_disk(shard_output_path) + logger.info(f"Saved empty dataset structure for shard {shard_index}.") + except Exception as save_e: + logger.error(f"Could not save empty dataset structure for shard {shard_index}: {save_e}") + processed_shard = empty_ds # Set processed_shard for stats + return # Exit after handling empty shard + + # --- 加载ASR Pipeline --- + logger.info(f"Loading ASR pipeline {MODEL_ID} on {device}...") + try: + asr_pipeline = pipeline( + "automatic-speech-recognition", + model=MODEL_ID, + torch_dtype=torch.float16, + device=device, + batch_size=pipeline_batch_size # Pipeline's internal batch size + ) + logger.info(f"ASR pipeline loaded successfully with internal batch size {pipeline_batch_size}.") + except Exception as e: + logger.error(f"Failed to load ASR pipeline: {e}", exc_info=True) + return + + # --- 使用 map 处理分片数据 --- + logger.info(f"Starting processing shard {shard_index} with map batch size {map_batch_size} and WER threshold {wer_threshold}...") + start_time = time.time() + # Define output features: Keep original columns + add new ones + output_features = Features({ + **ds_shard.features, # Keep original columns (incl. 
question_audio path) + 'asr_transcription': Value('string'), + 'wer': Value('float32'), + 'is_bad_tts': Value('bool'), + 'error_message': Value('string') + }) + # Remove check columns from output features + output_features.pop('audio_is_valid', None) + output_features.pop('audio_check_error', None) + output_features.pop('audio_duration_ms', None) + + processed_shard = ds_shard.map( + check_audio_quality_batch, + batched=True, + batch_size=map_batch_size, # map's batch size (rows passed to func) + fn_kwargs={ + "asr_pipeline": asr_pipeline, + "wer_threshold": wer_threshold, + "target_sr": TARGET_SR, + "shard_idx_for_log": shard_index + }, + features=output_features, # Define output schema + load_from_cache_file=False, # Disable caching + remove_columns=['audio_is_valid', 'audio_check_error', 'audio_duration_ms'] # Remove check columns during map + ) + end_time = time.time() + processing_time = end_time - start_time + logger.info(f"Shard {shard_index} processing finished in {processing_time:.2f} seconds.") + logger.info(f"Processed shard {shard_index} has columns: {processed_shard.column_names}") + + # --- 保存处理后的分片 --- + # No need to remove check columns here, done in map + shard_output_path = os.path.join(OUTPUT_DIR, f"shard_{shard_index}") + logger.info(f"Saving processed shard {shard_index} to {shard_output_path}...") + os.makedirs(OUTPUT_DIR, exist_ok=True) + try: + processed_shard.save_to_disk(shard_output_path) + logger.info(f"Shard {shard_index} saved successfully.") + except Exception as e: + # Check specifically for Arrow serialization issues if they occur + logger.error(f"Failed to save processed shard {shard_index} to {shard_output_path}: {e}", exc_info=True) + # The IndexError related to soundfile should NOT happen now + + finally: + # --- 记录统计信息 --- + if processed_shard is not None: + log_shard_statistics(processed_shard, shard_index, wer_threshold, processing_time) + else: + logger.warning("Processing did not complete or failed early. No statistics to log.") + + logger.info(f"Process for Shard {shard_index} on GPU {gpu_id} finished.") + if fh: + logger.removeHandler(fh) + fh.close() + +if __name__ == "__main__": + # Add librosa to requirements check potentially + try: + import librosa + except ImportError: + print("Error: librosa is required. 
Please install it using: pip install librosa") + exit(1) + main() \ No newline at end of file diff --git a/r1-a/dataset/sciq.py b/r1-a/dataset/sciq.py new file mode 100644 index 0000000000000000000000000000000000000000..d98b370b500595bc46d6e1c85a9a1f352c7958bd --- /dev/null +++ b/r1-a/dataset/sciq.py @@ -0,0 +1,176 @@ +import os +import random +import torch +import torchaudio +from datasets import load_dataset, Dataset +import sys +from tqdm import tqdm + +sys.path.append('/root/autodl-tmp/CosyVoice') +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + +# ------------------------ +# 配置参数 +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" +DATASET_NAME = "sciq" # 目标数据集:SciQ +OUTPUT_DATASET_PATH = './sciq_with_audio' +SAMPLE_RATE = 16000 + +# ------------------------ +# 辅助函数 +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli (替代原 Common Voice) 数据集中随机抽取一条语音及对应文本作为 prompt。 + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + sample = common_voice_dataset.select([idx])[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) + sr = audio['sampling_rate'] + if sr != sample_rate: + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + return waveform.unsqueeze(0), sample['raw_text'] + +def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + """ + try: + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) + # 可选:保存 prompt.wav 以做调试 + # torchaudio.save('prompt.wav', prompt_speech, SAMPLE_RATE) + + all_speech = [] + for i, j in enumerate(cosyvoice.inference_zero_shot( + query_text, + prompt_text, + prompt_speech, + stream=stream, + text_frontend=False + )): + all_speech.append(j['tts_speech']) + + # 将所有生成的语音片段拼接在一起 + combined_speech = torch.cat(all_speech, dim=-1) + sample_rate_val = cosyvoice.sample_rate + + return { + 'audio_tensor': combined_speech, + 'sample_rate': sample_rate_val + } + except Exception as e: + print(f"Error converting text to audio: {e}") + return None + +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对 SciQ 数据集中的单个样本进行 TTS 处理。 + 假设我们只对 sample['question'] 做 TTS。 + """ + query = example['question'] # 可根据需要修改要转换的文本字段 + audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False) + if audio_result is not None: + return { + 'audio_tensor': audio_result['audio_tensor'], + 'sample_rate': audio_result['sample_rate'] + } + else: + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset...") +common_voice = load_dataset("facebook/voxpopuli", "en", split='train') +print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}") + +print("Initializing CosyVoice2 model...") +cosyvoice = CosyVoice2( + '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B', # 替换为实际模型路径 + load_jit=True, + load_trt=False, + fp16=False +) + +print("Loading SciQ dataset...") +dataset = load_dataset("allenai/sciq") + +# 创建输出目录 +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) + +# ------------------------ +# 主处理循环 +# ------------------------ +final_dataset_dict = {} # 存放各 split 最终处理后的数据 + +for split_name, split_dataset in dataset.items(): + print(f"Processing split: {split_name} with 
{len(split_dataset)} examples") + split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name) + os.makedirs(split_output_dir, exist_ok=True) + + # 用于断点续跑的进度记录 + progress_file = os.path.join(split_output_dir, "progress.txt") + start_index = 0 + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + start_index = int(f.read().strip()) + print(f"Resuming split '{split_name}' from sample index {start_index}") + except Exception as e: + print(f"读取进度文件失败:{e}") + + final_samples = [] # 用于存储处理后数据 + + # 遍历处理每条样本 + for i in tqdm(range(len(split_dataset)), desc=f"Processing {split_name}"): + # 如果已处理过,就直接跳过并仅把已存在文件的元信息记入 final_samples + if i < start_index: + sample = split_dataset[i] + wav_path = os.path.join(split_output_dir, f"{i}.wav") + if os.path.exists(wav_path): + # 保留所有原始字段 + 音频路径 + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = wav_path + final_samples.append(sample_dict) + continue + + sample = split_dataset[i] + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None: + audio_tensor = result['audio_tensor'] + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + sample_rate_val = result['sample_rate'] + + output_wav_path = os.path.join(split_output_dir, f"{i}.wav") + try: + torchaudio.save(output_wav_path, audio_tensor, sample_rate_val) + except Exception as e: + print(f"Failed to save wav for sample {i}: {e}") + continue + + # 保留所有原始字段 + 生成的音频路径 + sample_dict = {k: sample[k] for k in sample.keys()} + sample_dict["audio_filepath"] = output_wav_path + final_samples.append(sample_dict) + else: + print(f"Sample {i} processing failed, no audio generated.") + + # 更新进度记录 + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # 生成 Hugging Face Dataset 并落盘 + final_dataset_obj = Dataset.from_list(final_samples) + final_dataset_save_path = os.path.join(split_output_dir, "final_dataset") + final_dataset_obj.save_to_disk(final_dataset_save_path) + + print(f"Finished processing split: {split_name} with {len(final_samples)} final samples.") + final_dataset_dict[split_name] = final_dataset_obj + +print("所有分割处理完毕,最终数据集已保存。") diff --git a/r1-a/dataset/shp.py b/r1-a/dataset/shp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a9ddbc231d907bc20f7c10f486a6c76b402f48 --- /dev/null +++ b/r1-a/dataset/shp.py @@ -0,0 +1,148 @@ +import re +import os # 确保导入 os 用于保存 +from datasets import load_dataset, Dataset +from tqdm.auto import tqdm # 用于显示进度条 + +# --- 可调整的过滤参数 --- +# (保持不变) +SCORE_RATIO_THRESHOLD = 2.0 +MIN_CHOSEN_SCORE = 3 +MIN_HISTORY_WORDS = 10 +MAX_HISTORY_WORDS = 150 # 调整为 150 +MAX_URLS = 0 # 调整为 0 +MAX_NEWLINES = 5 +FORBIDDEN_PATTERNS = [ + r"```.*```", + r"\|.*\|.*\|", +] +MIN_RESPONSE_WORDS = 10 + +# --- 脚本主逻辑 --- + +def is_tts_friendly(text): + """检查文本是否大致适合 TTS""" + # (保持不变) + word_count = len(text.split()) + if not (MIN_HISTORY_WORDS <= word_count <= MAX_HISTORY_WORDS): + return False + if text.count('http') > MAX_URLS: # 使用调整后的 MAX_URLS + return False + if text.count('\n') > MAX_NEWLINES: + return False + for pattern in FORBIDDEN_PATTERNS: + if re.search(pattern, text, re.DOTALL): + return False + return True + +def filter_shp2_train_dataset(dataset_name="stanfordnlp/shp-2"): # 函数名稍作修改以反映其目的 + """ + 加载并过滤 SHP-2 数据集的 'train' split, + 返回高质量、适合 TTS 的偏好对。 + """ + split_to_process = 'train' # <--- 指定只处理 'train' split + print(f"加载数据集: {dataset_name}, split: {split_to_process}...") + + try: + # --- 修改点 1: 直接加载指定的 split --- + 
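+        # Editorial tip: `datasets` split slicing (e.g. split="train[:1000]")
+        # loads only a prefix of the split -- handy for a quick smoke test of
+        # the filters below before committing to the full run.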
train_dataset = load_dataset(dataset_name, split=split_to_process) + print(f"'{split_to_process}' split 加载完成。") + except Exception as e: + print(f"错误:无法加载数据集 {dataset_name} 的 '{split_to_process}' split。") + print(f"错误详情: {e}") + return [] # 返回空列表表示失败 + + filtered_data = [] + seen_histories = set() # 用于跟踪已经添加的 history,确保唯一性 + + print(f"\n开始处理 '{split_to_process}' split...") + # --- 修改点 2: 直接迭代加载的 train_dataset --- + for example in tqdm(train_dataset, desc=f"过滤 {split_to_process} split"): + history = example.get("history") + human_ref_A = example.get("human_ref_A") + human_ref_B = example.get("human_ref_B") + labels = example.get("labels") + score_A = example.get("score_A") + score_B = example.get("score_B") + score_ratio = example.get("score_ratio") + domain = example.get("domain") + + # 基本检查 (保持不变) + if not all([history, human_ref_A, human_ref_B, labels is not None, + score_A is not None, score_B is not None, score_ratio is not None, domain]): + continue + + # 确保 history 未被处理过 (保持不变) + if history in seen_histories: + continue + + # 确定 chosen 和 reject (保持不变) + try: + label_int = int(labels) + if label_int == 1: + chosen = human_ref_A + reject = human_ref_B + chosen_score = score_A + elif label_int == 0: + chosen = human_ref_B + reject = human_ref_A + chosen_score = score_B + else: + continue + except (ValueError, TypeError): + continue + + # --- 应用过滤条件 (保持不变) --- + if score_ratio is None or not isinstance(score_ratio, (int, float)) or score_ratio < SCORE_RATIO_THRESHOLD: + continue + if chosen_score is None or not isinstance(chosen_score, (int, float)) or chosen_score < MIN_CHOSEN_SCORE: + continue + if not is_tts_friendly(history): + continue + if len(chosen.split()) < MIN_RESPONSE_WORDS or len(reject.split()) < MIN_RESPONSE_WORDS: + continue + + # --- 如果所有过滤条件都通过 (保持不变) --- + filtered_data.append({ + "query": history, + "chosen": chosen, + "reject": reject, + "domain": domain, + }) + seen_histories.add(history) + + print(f"\n过滤完成。从 '{split_to_process}' split 中总共筛选出 {len(filtered_data)} 条高质量样本。") + return filtered_data + +# --- 主程序 --- +if __name__ == "__main__": + # 执行过滤 (调用修改后的函数) + filtered_examples = filter_shp2_train_dataset() + + if filtered_examples: + # 将过滤后的数据转换为 Hugging Face Dataset 对象 (保持不变) + filtered_dataset = Dataset.from_list(filtered_examples) + + # 保存过滤后的数据集 (保持不变) + output_path = "./shp2_filtered_tts_high_quality_train_only" # 修改输出路径以反映内容 + print(f"正在保存过滤后的训练集数据到: {output_path}") + # 确保输出目录存在 + os.makedirs(os.path.dirname(output_path), exist_ok=True) # 如果 output_path 是目录,这行不需要 + filtered_dataset.save_to_disk(output_path) + print("数据集保存完成。") + + # 打印一些样本看看 (保持不变) + print("\n部分样本预览:") + # 从保存的 Dataset 加载并预览,确保保存成功 + try: + loaded_dataset = Dataset.load_from_disk(output_path) + for i in range(min(5, len(loaded_dataset))): + sample = loaded_dataset[i] + print(f"--- 样本 {i+1} ---") + print(f"Domain: {sample['domain']}") + print(f"Query: {sample['query'][:200]}...") + print(f"Chosen: {sample['chosen'][:200]}...") + except Exception as e: + print(f"加载预览样本时出错: {e}") # 增加错误处理 + + else: + print("没有找到符合条件的样本,请检查过滤参数设置或确认 'train' split 是否存在且包含数据。") \ No newline at end of file diff --git a/r1-a/dataset/shp_tts.py b/r1-a/dataset/shp_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..25f2d372a520fc32b52fc05f1ae4bc5d686cdfa7 --- /dev/null +++ b/r1-a/dataset/shp_tts.py @@ -0,0 +1,494 @@ +# --- SET CUDA DEVICE --- +# Method 1: Set environment variable BEFORE importing torch/cosyvoice +# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally. 
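+#
+# Illustrative sketch (an assumption for clarity, assuming a machine with at
+# least two GPUs): once CUDA_VISIBLE_DEVICES pins the script to one physical
+# GPU, torch renumbers the devices, e.g.:
+#
+#   import os
+#   os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # physical GPU 1
+#   import torch
+#   torch.cuda.device_count()    # -> 1
+#   torch.cuda.current_device()  # -> 0, i.e. 'cuda:0' is physical GPU 1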
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+# --- End CUDA Device Setting ---
+
+import random
+import torch
+import torchaudio
+# Make sure necessary types are imported
+from datasets import load_dataset, Dataset, load_from_disk, Features, Value
+import sys
+from tqdm import tqdm
+import time
+import shutil # Added for potentially removing old dataset save dirs
+
+# Check if the specified GPU is available after setting the environment variable
+if not torch.cuda.is_available():
+    print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'. Check your PyTorch installation, GPU drivers, and that GPU 1 exists and is functional.")
+    # Force exit if the intended GPU is not found
+    sys.exit(1)
+else:
+    # Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0
+    effective_device = torch.device("cuda:0")
+    try:
+        print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1
+        print(f"Script will effectively run on: {effective_device}")
+        # Perform a small check to ensure the device is usable
+        _ = torch.tensor([1.0]).to(effective_device)
+        print("Device check successful.")
+    except Exception as e:
+        print(f"ERROR: Failed CUDA device check for visible device 'cuda:0' (original GPU 1): {e}")
+        sys.exit(1)
+
+
+# Ensure CosyVoice path is correct
+COSYVOICE_PATH = '/home/chenyifu/CosyVoice' # Make sure this path is correct
+if not os.path.isdir(COSYVOICE_PATH):
+    print(f"ERROR: CosyVoice path not found: {COSYVOICE_PATH}")
+    sys.exit(1)
+sys.path.append(COSYVOICE_PATH)
+
+# Import CosyVoice *after* setting the environment variable
+try:
+    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+    from cosyvoice.utils.file_utils import load_wav
+    print("CosyVoice imported successfully.")
+except ImportError as e:
+    print(f"Error importing CosyVoice: {e}")
+    print(f"Please ensure the path '{COSYVOICE_PATH}' is correct and the library is installed within that directory.")
+    sys.exit(1)
+except Exception as e:
+    print(f"An unexpected error occurred during CosyVoice import: {e}")
+    sys.exit(1)
+
+# ------------------------
+# Configuration (MODIFIED FOR NEW DATASET)
+# ------------------------
+COMMON_VOICE_LANGUAGE = "en" # Language for prompts
+
+# --- !! MODIFIED !!
--- +# Input: Path to the dataset created by the previous selection script +INPUT_DATASET_PATH = "./shp2_final_top20_percent/train_split_top20_percent_by_complexity" +# Output: Directory to save new audio files and the final dataset object +OUTPUT_DATASET_PATH = './shp2_top20_percent_with_query_audio' +# --- End MODIFIED --- + +SAMPLE_RATE = 16000 # Target sample rate for TTS output (should match CosyVoice default) +MAX_TTS_RETRIES = 3 +RETRY_DELAY_SECONDS = 3 # Slightly increased delay + +# ------------------------ +# 辅助函数 (GPU handling and core TTS logic - UNCHANGED as requested) +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli 数据集中随机抽取一条语音及对应文本作为 prompt。 + (Logic remains unchanged) + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + try: + # Use select().with_format('numpy') for potentially better memory handling with large datasets + sample = common_voice_dataset.select([idx]).with_format('numpy')[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) # Created on CPU + sr = audio['sampling_rate'] + + if sr != sample_rate: + # Ensure waveform is 1D before resampling + if waveform.dim() > 1: + waveform = waveform.mean(dim=0) + if waveform.dim() != 1: + print(f"Warning: Unexpected waveform dimension {waveform.dim()} before resampling. Skipping prompt.") + return get_random_prompt(common_voice_dataset, sample_rate) # Retry + + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + + # Ensure output is 2D [1, T] + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + elif waveform.dim() > 2: + print(f"Warning: Unexpected waveform dimension {waveform.dim()} after resampling. Skipping prompt.") + return get_random_prompt(common_voice_dataset, sample_rate) # Retry + + raw_text = sample.get('raw_text', '') + if waveform.numel() == 0 or not raw_text or not raw_text.strip(): + # print("Warning: Got an empty audio or text prompt, trying again...") + return get_random_prompt(common_voice_dataset, sample_rate) # Retry + + # Return CPU tensor, CosyVoice inference should handle moving it + return waveform, raw_text + except Exception as e: + print(f"Error getting random prompt at index {idx}: {e}. Retrying...") + time.sleep(0.1) # Small delay before retry + return get_random_prompt(common_voice_dataset, sample_rate) + +def text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + Includes retry logic on failure. Assumes cosyvoice runs on the configured device. 
+ (Logic remains unchanged) + """ + last_exception = None + prompt_speech = None + prompt_text = "N/A" + + for attempt in range(max_retries): + try: + # Get prompt - ensures it's valid this time + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) + # prompt_speech is initially on CPU + + all_speech = [] + # cosyvoice.inference_zero_shot should internally use the GPU device it was initialized on + # (which should be the visible cuda:0, i.e., original cuda:1) + inference_generator = cosyvoice.inference_zero_shot( + text_to_convert, + prompt_text, + prompt_speech, # Pass CPU tensor + stream=stream, + text_frontend=False # Assuming default frontend is desired + ) + # Generated chunks 'tts_speech' will be on the GPU + for i, chunk in enumerate(inference_generator): + if chunk is None: + print(f"Warning: Received None chunk {i} during TTS generation for text '{text_to_convert[:60]}...'") + continue + if 'tts_speech' in chunk and chunk['tts_speech'] is not None and chunk['tts_speech'].numel() > 0: + # Ensure chunk is on the correct device (should be already, but belt-and-suspenders) + gpu_chunk = chunk['tts_speech'].to(effective_device) + all_speech.append(gpu_chunk) + # else: # Reduce log noise + # print(f"Debug: Chunk {i} missing 'tts_speech' or is empty for text '{text_to_convert[:60]}...'") + + + if not all_speech: + # Clear GPU memory cache if an error occurs during generation + if torch.cuda.is_available(): torch.cuda.empty_cache() + raise ValueError("TTS inference finished but produced no valid audio chunks.") + + # combined_speech is on GPU + combined_speech = torch.cat(all_speech, dim=-1) + sample_rate_val = cosyvoice.sample_rate + + # --- Add a check for silence --- + # Check max absolute amplitude; threshold might need tuning + if torch.max(torch.abs(combined_speech)) < 0.001: + print(f"Warning: Generated audio appears to be silent for text '{text_to_convert[:60]}...'. Retrying...") + raise ValueError("Generated audio is silent") + + + return { + # Return GPU tensor, will be moved to CPU before saving + 'audio_tensor': combined_speech, + 'sample_rate': sample_rate_val + } + + except Exception as e: + last_exception = e + print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}") + print(f" Text: '{text_to_convert[:100]}...'") + print(f" Prompt Text Used: '{prompt_text[:100]}...'") + # Clear GPU cache on error + if torch.cuda.is_available(): torch.cuda.empty_cache() + if attempt < max_retries - 1: + print(f" Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...") + time.sleep(RETRY_DELAY_SECONDS) + else: + print(f" All {max_retries} TTS attempts failed.") + + print(f"Failed to generate audio for text after {max_retries} attempts: '{text_to_convert[:60]}...'") + if last_exception: + print(f"Last error: {last_exception}") + # Explicitly return None on failure + return None + +# --- !! MODIFIED process_example !! --- +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对从磁盘加载的 *SHP-2 Top 20%* 数据集中的单个样本进行 TTS 处理。 + Processes the example['query'] field. + """ + # --- MODIFIED: Target the 'query' field --- + text_to_convert = example.get('query') + # --- End MODIFIED --- + + if not text_to_convert or not isinstance(text_to_convert, str) or text_to_convert.strip() == "": + # --- MODIFIED: Update warning message --- + print(f"Warning: Skipping example due to missing or empty 'query' field. 
Keys: {list(example.keys())}") + # --- End MODIFIED --- + return None + + # Call the unchanged text_to_audio function + audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False) + + if audio_result is not None: + audio_tensor = audio_result['audio_tensor'] # Still on GPU here + # Basic validation of the tensor + if audio_tensor is None or audio_tensor.numel() == 0: + print(f"Warning: TTS process returned empty tensor for query: '{text_to_convert[:60]}...'") + return None + + # Ensure correct shape (should be [1, T] from text_to_audio) + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + elif audio_tensor.dim() > 2: + print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.") + audio_tensor = audio_tensor.view(1, -1) # Flatten to [1, T] + + # Double-check for emptiness after potential reshape + if audio_tensor.numel() == 0: + print(f"Warning: Generated audio tensor became empty after reshape for query: '{text_to_convert[:60]}...'") + return None + + return { + 'audio_tensor': audio_tensor, # Return GPU tensor + 'sample_rate': audio_result['sample_rate'] + } + else: + # text_to_audio already prints detailed errors + return None + +# ------------------------ +# 数据加载与模型初始化 +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset for prompts...") +try: + # Load prompt dataset to CPU memory + common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train', trust_remote_code=True) + # Filter potentially problematic samples (optional, but can help) + common_voice = common_voice.filter(lambda x: x['audio'] is not None and x['audio']['array'] is not None and x['raw_text'] is not None and len(x['raw_text'].strip()) > 5 and len(x['audio']['array']) > SAMPLE_RATE * 0.5) # Min 0.5 sec, non-empty text + print(f"Loaded and filtered VoxPopuli '{COMMON_VOICE_LANGUAGE}' samples: {len(common_voice)}") + if len(common_voice) == 0: + raise ValueError(f"VoxPopuli dataset '{COMMON_VOICE_LANGUAGE}' loaded but contains no valid samples after filtering.") +except Exception as e: + print(f"Error loading or filtering VoxPopuli dataset: {e}") + sys.exit(1) + + +print("Initializing CosyVoice2 model...") +try: + # CosyVoice should automatically initialize on the visible device ('cuda:0', which is original 'cuda:1') + cosyvoice_model_path = os.path.join(COSYVOICE_PATH, 'pretrained_models/CosyVoice2-0.5B') + if not os.path.isdir(cosyvoice_model_path): + print(f"ERROR: CosyVoice pretrained model directory not found: {cosyvoice_model_path}") + sys.exit(1) + + cosyvoice = CosyVoice2( + cosyvoice_model_path, + load_jit=True, # Assuming JIT is preferred + load_trt=False, # Ensure TRT is False if not set up for GPU 1 + fp16=False # Keep FP16 False unless GPU 1 is known to handle it well and has enough VRAM + # device=effective_device # Usually not needed if CUDA_VISIBLE_DEVICES is set + ) + print(f"CosyVoice model initialized. Target device: {effective_device}") + # Verify model is on the correct device (optional check) + if hasattr(cosyvoice, 'model') and hasattr(cosyvoice.model, 'device'): + print(f"CosyVoice internal model device: {cosyvoice.model.device}") + elif hasattr(cosyvoice, 'device'): + print(f"CosyVoice main object device: {cosyvoice.device}") + +except Exception as e: + print(f"Error initializing CosyVoice2 model: {e}") + if isinstance(e, RuntimeError) and 'CUDA' in str(e): + print("This might be a CUDA initialization error. 
Ensure GPU 1 is functional, has enough memory, and required CUDA toolkit versions are compatible.") + sys.exit(1) + +# --- !! MODIFIED Dataset Loading !! --- +print(f"\nLoading the target dataset from disk: {INPUT_DATASET_PATH}") +if not os.path.exists(INPUT_DATASET_PATH): + print(f"Error: Input dataset directory not found at '{INPUT_DATASET_PATH}'.") + print("Please ensure the previous (selection) script ran successfully and produced the dataset at this location.") + sys.exit(1) + +try: + input_dataset = load_from_disk(INPUT_DATASET_PATH) + + print(f"Successfully loaded dataset with {len(input_dataset)} examples.") + if len(input_dataset) == 0: + print("Error: The loaded dataset is empty. Cannot proceed.") + sys.exit(1) + # Store original features to reconstruct the final dataset correctly + original_features = input_dataset.features + print(f"Original features: {original_features}") + # Check for 'query' column existence + if 'query' not in original_features: + print(f"Error: The loaded dataset from '{INPUT_DATASET_PATH}' does not contain the required 'query' column.") + sys.exit(1) + +except Exception as e: + print(f"Error loading dataset from '{INPUT_DATASET_PATH}': {e}") + sys.exit(1) +# --- End MODIFIED Dataset Loading --- + + +# --- Create output directories --- +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) +# Subdirectory for the actual audio files +audio_output_dir = os.path.join(OUTPUT_DATASET_PATH, "audio_files") +os.makedirs(audio_output_dir, exist_ok=True) +print(f"Audio files will be saved in: {audio_output_dir}") +# Path for the progress tracking file +progress_file = os.path.join(OUTPUT_DATASET_PATH, "progress.txt") +print(f"Progress will be tracked in: {progress_file}") + + +# ------------------------ +# 主处理循环 (MODIFIED FOR SINGLE DATASET) +# ------------------------ +print(f"\nStarting TTS processing for {len(input_dataset)} samples...") + +start_index = 0 +# Read progress file to resume if necessary +if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + content = f.read().strip() + if content: + start_index = int(content) + print(f"Resuming TTS processing from sample index {start_index}") + else: + print(f"Progress file '{progress_file}' is empty, starting TTS from index 0.") + start_index = 0 + except ValueError: + print(f"Could not parse integer from progress file '{progress_file}'. Starting TTS from index 0.") + start_index = 0 + except Exception as e: + print(f"Error reading progress file '{progress_file}': {e}. 
Starting TTS from index 0.") + start_index = 0 + +# List to hold dictionaries for the final dataset +final_samples = [] + +# --- Main Loop --- +pbar = tqdm(range(start_index, len(input_dataset)), desc=f"TTS on 'query' field", initial=start_index, total=len(input_dataset)) +for i in pbar: + sample = input_dataset[i] # Get sample dictionary (on CPU) + + # Define unique output WAV path using the index + # Using index is simple, assumes dataset order is stable during processing + output_wav_filename = f"query_{i}.wav" + output_wav_path = os.path.join(audio_output_dir, output_wav_filename) + + # --- Check if audio file already exists --- + if os.path.exists(output_wav_path): + # If already processed, create the dict for the final dataset + sample_dict = dict(sample) # Copy original data + sample_dict["query_audio_filepath"] = output_wav_path # Add the path field + final_samples.append(sample_dict) + # Update progress file even when skipping (to ensure it reflects the latest processed/checked index) + with open(progress_file, "w") as f: + f.write(str(i + 1)) + continue # Skip TTS for this sample + + # --- Perform TTS on the target device --- + # process_example handles getting the 'query' text and calling text_to_audio + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None and 'audio_tensor' in result and result['audio_tensor'] is not None: + audio_tensor = result['audio_tensor'] # Received tensor is on GPU + sample_rate_val = result['sample_rate'] + + try: + # --- Move tensor to CPU before saving --- + audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32) + + # Ensure shape is 2D [1, T] for torchaudio.save + if audio_tensor_save.dim() == 1: + audio_tensor_save = audio_tensor_save.unsqueeze(0) + elif audio_tensor_save.dim() > 2: + print(f"Warning: Flattening unexpected tensor shape {audio_tensor_save.shape} before saving.") + audio_tensor_save = audio_tensor_save.view(1, -1) + + # Save the audio file + torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val) + + # Create dict for the final dataset including the new path + sample_dict = dict(sample) # Copy original data + sample_dict["query_audio_filepath"] = output_wav_path # Add the path field + final_samples.append(sample_dict) + + # --- Explicitly delete GPU tensor --- + del audio_tensor + # No need to delete audio_tensor_save as it's on CPU + + except Exception as e: + print(f"Failed to save wav for sample {i} ('query' field TTS) at {output_wav_path}: {e}") + # Attempt to remove partially saved/corrupted file if save failed + if os.path.exists(output_wav_path): + try: os.remove(output_wav_path) + except OSError: pass + # Clear cache on save error too + if torch.cuda.is_available(): torch.cuda.empty_cache() + else: + # Log failure (process_example or text_to_audio already logged details) + query_text = sample.get('query', 'N/A') + print(f"Sample {i} TTS failed or produced no audio after retries (Query Text: '{query_text[:60]}...'). 
Audio file not saved.") + # Ensure cache is cleared even on TTS failure + if torch.cuda.is_available(): torch.cuda.empty_cache() + + + # --- Update progress file --- + # Write the index of the *next* sample to start from if resuming + with open(progress_file, "w") as f: + f.write(str(i + 1)) + + # --- Optional: Periodic cache clearing --- + if i > 0 and i % 50 == 0: # Example: clear cache every 50 iterations (adjust as needed) + if torch.cuda.is_available(): + # print(f"Clearing CUDA cache at iteration {i}...") # Debug log + torch.cuda.empty_cache() + + +# --- Final cache clear after finishing the loop --- +if torch.cuda.is_available(): + print("Clearing final CUDA cache...") + torch.cuda.empty_cache() + +# ------------------------ +# 保存最终数据集 (MODIFIED) +# ------------------------ +print("\nTTS processing loop finished.") +if final_samples: + print(f"Successfully processed (or skipped existing) {len(final_samples)} samples.") + + # --- Define features for the new dataset --- + # Start with original features and add the new audio path column + new_features_dict = original_features.copy() + new_column_name = 'query_audio_filepath' + if new_column_name in new_features_dict: + print(f"Warning: Feature '{new_column_name}' already exists in original features. Overwriting.") + new_features_dict[new_column_name] = Value('string') # Add the new column definition + try: + new_features = Features(new_features_dict) + print(f"Defined new features for saving: {new_features}") + + # --- Create the final Dataset object --- + print("Creating final Dataset object from processed samples...") + final_dataset_obj = Dataset.from_list(final_samples, features=new_features) + + # --- Define path to save the final dataset metadata object --- + # This object contains the original data + the new filepath column + final_dataset_save_path = os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio") + print(f"Saving final dataset metadata (with audio paths) to: {final_dataset_save_path}...") + + # Ensure the target directory exists and is empty before saving + if os.path.exists(final_dataset_save_path): + print(f"Removing existing directory before saving: {final_dataset_save_path}") + shutil.rmtree(final_dataset_save_path) + # The save_to_disk function will create the directory + # os.makedirs(os.path.dirname(final_dataset_save_path), exist_ok=True) # Not needed if saving to the dir itself + + final_dataset_obj.save_to_disk(final_dataset_save_path) + print(f"Final dataset object saved successfully.") + + except Exception as e: + print(f"\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print(f"Error during final dataset creation or saving: {e}") + print(f"Audio files might be saved in '{audio_output_dir}', but the final dataset object could not be created/saved.") + print(f"Check the features and the content of 'final_samples'.") + print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + +else: + print("Processing finished, but no samples were successfully processed or had existing audio files.") + print(f"Check logs for TTS errors. 
Audio files directory: '{audio_output_dir}'.")
+
+
+print("\n" + "="*60)
+print(f"Script finished.")
+print(f"Generated audio files are located in: '{audio_output_dir}'")
+print(f"The final dataset (metadata + audio file paths) is saved at: '{os.path.join(OUTPUT_DATASET_PATH, 'processed_dataset_with_audio')}' (if saving was successful)")
+print("="*60)
\ No newline at end of file
diff --git a/r1-a/dataset/ultrachat.py b/r1-a/dataset/ultrachat.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb9f9800b2772256ae574816376d0c413b3d86d
--- /dev/null
+++ b/r1-a/dataset/ultrachat.py
@@ -0,0 +1,261 @@
+import re
+import os
+from datasets import load_dataset, Dataset
+from tqdm.auto import tqdm
+import json
+import string  # string module used for character checks
+
+# --- Tunable filter parameters ---
+MIN_USER_QUERY_WORDS = 5
+MAX_USER_QUERY_WORDS = 150
+SIMPLE_PROMPT_PATTERNS = [
+    r"^\s*(ok|yes|no|thanks?|got it|great|cool|sounds good|perfect|alright|fine|bye|goodbye)[.!\s]*$",
+    r"^\s*\?+\s*$",
+    r"^\s*i see\.?\s*$",
+    r"^\s*you'?re welcome\.?\s*$",  # a few more simple acknowledgements
+    r"^\s*okay then\.?\s*$",
+]
+CORRUPTED_ENDINGS = [" user", " assistan"]
+MAX_QUERY_URLS = 0
+MAX_QUERY_NEWLINES = 3
+MIN_DIALOGUE_TURNS = 2  # minimum length of the messages list
+
+# --- New: filters for code and TTS-unfriendly content ---
+FILTER_CODE_KEYWORDS = True  # filter queries containing common code keywords
+CODE_KEYWORDS_PATTERN = r"\b(def|class|import|function|const|let|var|public|private|static|void|main|int|float|str|bool|return|yield|async|await|try|except|finally|if|else|for|while|switch|case|break|continue|lambda|map|filter|reduce|numpy|pandas|torch|tensorflow|react|angular|vue|console\.log|System\.out\.println)\b"  # common programming keywords (extensible)
+
+FILTER_INLINE_CODE = True  # filter queries containing Markdown inline code `...`
+INLINE_CODE_PATTERN = r"`[^`]+`"
+
+FILTER_MARKDOWN_TABLE_SEP = True  # filter queries containing Markdown table separators `|---|`
+MARKDOWN_TABLE_SEP_PATTERN = r"\|-+\|"
+
+FILTER_EXCESSIVE_SPECIAL_CHARS = True  # filter queries with an excessive ratio of special characters
+MAX_SPECIAL_CHAR_RATIO = 0.25  # maximum allowed ratio of special characters (not letters, digits, or spaces)
+
+FILTER_LONG_STRINGS_NO_SPACE = True  # filter queries containing overly long strings without spaces
+MAX_NO_SPACE_STRING_LEN = 50  # maximum allowed length of a string without spaces
+
+QUERY_FORBIDDEN_PATTERNS = [
+    r"```",  # code block marker (pre-existing)
+    # r"\|.*\|.*\|",  # simple table-row detection (possibly too broad; the separator check below may be better)
+    # new patterns are appended dynamically based on the switches above
+]
+
+# --- Main script logic ---
+
+def is_potentially_garbled(text):
+    if not text or not isinstance(text, str): return True
+    for ending in CORRUPTED_ENDINGS:
+        if text.endswith(ending): return True
+    # Slightly relaxed bracket check: only flag severe imbalance
+    if text.count('{') > text.count('}') + 2 or text.count('[') > text.count(']') + 2: return True
+    if text.count('```') % 2 != 0: return True  # unclosed code block
+    return False
+
+def is_prompt_suitable(text, turn_index):  # turn_index added for debugging
+    """Check whether a user query meets the quality and TTS requirements"""
+    if not text or not isinstance(text, str):
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Not text or empty")
+        return False
+
+    # --- Basic checks ---
+    word_count = len(text.split())
+    if not (MIN_USER_QUERY_WORDS <= word_count <= MAX_USER_QUERY_WORDS):
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Word count {word_count} out of range [{MIN_USER_QUERY_WORDS}, {MAX_USER_QUERY_WORDS}]")
+        return False
+
+    text_stripped = text.strip()
+    for pattern in SIMPLE_PROMPT_PATTERNS:
+        if re.fullmatch(pattern, text_stripped, re.IGNORECASE):
+            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Matched simple pattern '{pattern}'")
+            return False
+
+    if text.count('http') > MAX_QUERY_URLS:
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Too many URLs")
+        return False
+    if text.count('\n') > MAX_QUERY_NEWLINES:
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Too many newlines")
+        return False
+
+    # --- Generic forbidden-pattern check (pre-existing plus dynamically added) ---
+    current_forbidden_patterns = list(QUERY_FORBIDDEN_PATTERNS)  # copy the base list
+    if FILTER_MARKDOWN_TABLE_SEP:
+        current_forbidden_patterns.append(MARKDOWN_TABLE_SEP_PATTERN)
+
+    for pattern in current_forbidden_patterns:
+        # re.DOTALL lets . match newlines; re.IGNORECASE can help for some patterns (e.g. keywords)
+        search_flags = re.DOTALL
+        if pattern == CODE_KEYWORDS_PATTERN:  # keywords must be matched case-insensitively
+            search_flags |= re.IGNORECASE
+        if re.search(pattern, text, search_flags):
+            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Matched forbidden pattern '{pattern}'")
+            return False
+
+    # --- New: specific checks for code and TTS-unfriendly content ---
+
+    # 1. Check for common code keywords
+    if FILTER_CODE_KEYWORDS and re.search(CODE_KEYWORDS_PATTERN, text, re.IGNORECASE):
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains code keywords")
+        return False
+
+    # 2. Check for Markdown inline code
+    if FILTER_INLINE_CODE and re.search(INLINE_CODE_PATTERN, text):
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains inline code")
+        return False
+
+    # 3. Check for overly long strings without spaces (possibly hashes, base64, code fragments, etc.)
+    if FILTER_LONG_STRINGS_NO_SPACE:
+        # \S matches any non-whitespace character
+        if re.search(r"\S{" + str(MAX_NO_SPACE_STRING_LEN) + r",}", text):
+            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Contains long string without spaces (>{MAX_NO_SPACE_STRING_LEN})")
+            return False
+
+    # 4. Check the ratio of special characters (not letters, digits, or spaces)
+    if FILTER_EXCESSIVE_SPECIAL_CHARS and len(text) > 0:  # avoid division by zero
+        special_chars = 0
+        total_chars = len(text)
+        for char in text:
+            # string.punctuation covers common punctuation; here we exclude letters,
+            # digits, and whitespace, and count everything else as special characters
+            if not char.isalnum() and not char.isspace():
+                special_chars += 1
+        ratio = special_chars / total_chars
+        if ratio > MAX_SPECIAL_CHAR_RATIO:
+            # print(f"DEBUG: Prompt rejected (turn {turn_index}): Excessive special characters ratio ({ratio:.2f} > {MAX_SPECIAL_CHAR_RATIO})")
+            return False
+
+    # --- Final garbled check ---
+    if is_potentially_garbled(text):
+        # print(f"DEBUG: Prompt rejected (turn {turn_index}): Potentially garbled")
+        return False
+
+    # print(f"DEBUG: Prompt accepted (turn {turn_index})")
+    return True
+
+# --- format_history unchanged ---
+def format_history(history_list):
+    """Format the list of history messages as text"""
+    if not history_list:
+        return ""
+    formatted = []
+    for msg in history_list:
+        role_tag = "[USER]" if msg.get('role') == 'user' else "[ASSISTANT]"
+        content = msg.get('content', '')
+        formatted.append(f"{role_tag}\n{content}")
+    return "\n\n".join(formatted)
+
+
+# --- filter_ultrachat_dataset_v2 unchanged (apart from calling the updated is_prompt_suitable) ---
+def filter_ultrachat_dataset_v2(dataset_name="HuggingFaceH4/ultrachat_200k", split="train_sft"):
+    """
+    Load and filter the UltraChat dataset (structure access corrected per the screenshot).
+    Filters with the updated is_prompt_suitable.
+    """
+    print(f"Loading dataset: {dataset_name}, split: {split}...")
+    try:
+        dataset = load_dataset(dataset_name, split=split)
+        print(f"'{split}' split loaded.")
+    except Exception as e:
+        print(f"Error: could not load the '{split}' split of {dataset_name}.")
+        print(f"Details: {e}")
+        return []
+
+    filtered_samples = []
+    processed_dialogues = 0
+    extracted_samples = 0
+    skipped_garbled_dialogue = 0
+    skipped_short_dialogue = 0
+    skipped_bad_format = 0
+
+    print(f"\nProcessing dialogues in the '{split}' split...")
+    for dialogue in tqdm(dataset, desc="Processing dialogues"):
+        processed_dialogues += 1
+
+        messages = dialogue.get("messages")
+        prompt_id = dialogue.get("prompt_id")
+        initial_prompt = dialogue.get("prompt")
+
+        if not prompt_id: continue
+        if not messages or not isinstance(messages, list):
+            skipped_bad_format += 1
+            continue
+        if len(messages) < MIN_DIALOGUE_TURNS:
+            skipped_short_dialogue += 1
+            continue
+
+        dialogue_seems_garbled = False
+        for msg in messages:
+            content = msg.get("content")
+            # Dialogue-level corruption check now relies solely on is_potentially_garbled
+            if is_potentially_garbled(content):
+                dialogue_seems_garbled = True
+                break
+        if dialogue_seems_garbled:
+            skipped_garbled_dialogue += 1
+            continue
+
+        current_history_list = []
+        for i, message in enumerate(messages):
+            role = message.get("role")
+            content = message.get("content", "").strip()
+
+            if not role or not content:
+                continue
+
+            if role == "user":
+                # Call the updated filter function
+                if is_prompt_suitable(content, i):
+                    history_text = format_history(current_history_list)
+                    filtered_samples.append({
+                        "dialogue_id": prompt_id,
+                        "turn_index": i,
+                        "query": content,
+                        "history": history_text
+                    })
+                    extracted_samples += 1
+                # else:  # uncomment the internal prints to see rejection reasons
+                #     pass
+
+            current_history_list.append({"role": role, "content": content})
+
+    print("\nFiltering done.")
+    print(f"Dialogues processed: {processed_dialogues}")
+    print(f"Skipped for bad format: {skipped_bad_format}")
+    print(f"Skipped because the messages list was too short (<{MIN_DIALOGUE_TURNS} turns): {skipped_short_dialogue}")
+    print(f"Dialogues skipped as likely corrupted (via is_potentially_garbled): {skipped_garbled_dialogue}")
+    print(f"Valid user-query samples extracted: {extracted_samples}")
+    return filtered_samples
+
+# --- Main (unchanged, but calls the V2 function) ---
+if __name__ == "__main__":
+    # Run the corrected filter function
+    filtered_data_list = filter_ultrachat_dataset_v2(dataset_name="HuggingFaceH4/ultrachat_200k", split="train_sft")
+
+    if filtered_data_list:
+        filtered_dataset = Dataset.from_list(filtered_data_list)
+        # Output directory name updated to reflect the new filter rules
+        output_path = "./ultrachat_filtered_for_tts_preference_v3_nocode"
+        print(f"\nSaving the filtered dataset to: {output_path}")
+        os.makedirs(output_path, exist_ok=True)
+        filtered_dataset.save_to_disk(output_path)
+        print("Dataset saved.")
+
+        print("\nSample preview (loaded from the saved Dataset):")
+        try:
+            loaded_dataset = Dataset.load_from_disk(output_path)
+            for i in range(min(5, len(loaded_dataset))):
+                sample = loaded_dataset[i]
+                print(f"--- Sample {i+1} (Dialogue ID: {sample['dialogue_id']}, Turn: {sample['turn_index']}) ---")
+                print(f"History (last 500 chars):\n...{sample['history'][-500:]}")
+                print(f"\nQuery: {sample['query']}")
+                print("-" * 20)
+        except Exception as e:
+            print(f"Error loading preview samples: {e}")
+
+    else:
+        print("\nNo qualifying samples found. Possible causes:")
+        print("1. Filter parameters are too strict (check MIN/MAX word counts, SIMPLE_PROMPT_PATTERNS, the new code/TTS filter parameters, etc.).")
+        print("2. The `is_potentially_garbled` rules are misfiring.")
+        print("3. The dataset simply has no qualifying dialogues in this split.")
+        print("4. (Check the skip counters printed by the script to see which stage dropped most samples.)")
\ No newline at end of file
diff --git a/r1-a/dataset/ultrachat_tts.py b/r1-a/dataset/ultrachat_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8963532a03fdfd684c0feb9093bd0060c1df9a
--- /dev/null
+++ b/r1-a/dataset/ultrachat_tts.py
@@ -0,0 +1,382 @@
+# --- SET CUDA DEVICE ---
+# Method 1: Set environment variable BEFORE importing torch/cosyvoice
+# This makes only GPU 1 visible to the script, appearing as 'cuda:0' internally.
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+# --- End CUDA Device Setting ---
+
+import random
+import torch
+import torchaudio
+# Make sure necessary types are imported
+from datasets import load_dataset, Dataset, load_from_disk, Features, Value
+import sys
+from tqdm import tqdm
+import time
+import shutil # Added for potentially removing old dataset save dirs
+
+# Check if the specified GPU is available after setting the environment variable
+if not torch.cuda.is_available():
+    print("ERROR: CUDA is not available after setting CUDA_VISIBLE_DEVICES='1'.
Check your PyTorch installation, GPU drivers, and that GPU 1 exists and is functional.") + # Force exit if the intended GPU is not found + sys.exit(1) +else: + # Since CUDA_VISIBLE_DEVICES is set to '1', the first *visible* device is cuda:0 + effective_device = torch.device("cuda:0") + try: + print(f"CUDA device visible to PyTorch: {torch.cuda.get_device_name(0)}") # Should show the name of GPU 1 + print(f"Script will effectively run on: {effective_device}") + # Perform a small check to ensure the device is usable + _ = torch.tensor([1.0]).to(effective_device) + print("Device check successful.") + except Exception as e: + print(f"ERROR: Failed CUDA device check for visible device 'cuda:0' (original GPU 1): {e}") + sys.exit(1) + + +# Ensure CosyVoice path is correct +COSYVOICE_PATH = '/root/autodl-tmp/CosyVoice' # Make sure this path is correct +if not os.path.isdir(COSYVOICE_PATH): + print(f"ERROR: CosyVoice path not found: {COSYVOICE_PATH}") + sys.exit(1) +sys.path.append(COSYVOICE_PATH) + +# Import CosyVoice *after* setting the environment variable +try: + from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 + from cosyvoice.utils.file_utils import load_wav + print("CosyVoice imported successfully.") +except ImportError as e: + print(f"Error importing CosyVoice: {e}") + print(f"Please ensure the path '{COSYVOICE_PATH}' is correct and the library is installed within that directory.") + sys.exit(1) +except Exception as e: + print(f"An unexpected error occurred during CosyVoice import: {e}") + sys.exit(1) + +# ------------------------ +# 配置参数 (MODIFIED FOR Selected UltraChat DATASET) +# ------------------------ +COMMON_VOICE_LANGUAGE = "en" # Language for prompts + +# --- !! MODIFIED !! --- +# Input: Path to the SELECTED UltraChat dataset (Top 20%) from the previous script +INPUT_DATASET_PATH = "./ultrachat_final_top20_percent/ultrachat_top20_percent_by_complexity" +# Output: Directory to save new audio files and the final dataset object for THIS specific dataset +OUTPUT_DATASET_PATH = './ultrachat_top20_percent_with_query_audio' # New distinct output path +# --- End MODIFIED --- + +SAMPLE_RATE = 16000 # Target sample rate for TTS output (should match CosyVoice default) +MAX_TTS_RETRIES = 3 +RETRY_DELAY_SECONDS = 3 + +# ------------------------ +# 辅助函数 (GPU handling and core TTS logic - UNCHANGED as requested) +# ------------------------ +def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 从 VoxPopuli 数据集中随机抽取一条语音及对应文本作为 prompt。 + (Logic remains unchanged from previous TTS script) + """ + idx = random.randint(0, len(common_voice_dataset) - 1) + try: + sample = common_voice_dataset.select([idx]).with_format('numpy')[0] + audio = sample['audio'] + waveform = torch.tensor(audio['array'], dtype=torch.float32) # CPU + sr = audio['sampling_rate'] + if sr != sample_rate: + if waveform.dim() > 1: waveform = waveform.mean(dim=0) + if waveform.dim() != 1: return get_random_prompt(common_voice_dataset, sample_rate) + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) + waveform = resampler(waveform) + if waveform.dim() == 1: waveform = waveform.unsqueeze(0) + elif waveform.dim() > 2: return get_random_prompt(common_voice_dataset, sample_rate) + raw_text = sample.get('raw_text', '') + if waveform.numel() == 0 or not raw_text or not raw_text.strip(): + return get_random_prompt(common_voice_dataset, sample_rate) + return waveform, raw_text # Return CPU tensor + except Exception as e: + time.sleep(0.1) + return 
get_random_prompt(common_voice_dataset, sample_rate) + +def text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES): + """ + 利用 CosyVoice2 模型将输入文本转换为语音,采用随机 prompt 进行零样本推理。 + Includes retry logic on failure. Assumes cosyvoice runs on the configured device. + (Logic remains unchanged from previous TTS script) + """ + last_exception = None + prompt_speech, prompt_text = None, "N/A" + for attempt in range(max_retries): + try: + prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE) # CPU tensor + all_speech = [] + inference_generator = cosyvoice.inference_zero_shot( + text_to_convert, prompt_text, prompt_speech, stream=stream, text_frontend=False + ) + for i, chunk in enumerate(inference_generator): # Chunks on GPU + if chunk is None: continue + if 'tts_speech' in chunk and chunk['tts_speech'] is not None and chunk['tts_speech'].numel() > 0: + gpu_chunk = chunk['tts_speech'].to(effective_device) + all_speech.append(gpu_chunk) + if not all_speech: + if torch.cuda.is_available(): torch.cuda.empty_cache() + raise ValueError("TTS inference finished but produced no valid audio chunks.") + combined_speech = torch.cat(all_speech, dim=-1) # GPU tensor + sample_rate_val = cosyvoice.sample_rate + if torch.max(torch.abs(combined_speech)) < 0.001: + raise ValueError("Generated audio is silent") + return {'audio_tensor': combined_speech, 'sample_rate': sample_rate_val} # Return GPU tensor + except Exception as e: + last_exception = e + print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}") + print(f" Text: '{text_to_convert[:100]}...'") + # print(f" Prompt Text Used: '{prompt_text[:100]}...'") # Reduce log noise + if torch.cuda.is_available(): torch.cuda.empty_cache() + if attempt < max_retries - 1: + print(f" Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...") + time.sleep(RETRY_DELAY_SECONDS) + else: + print(f" All {max_retries} TTS attempts failed.") + return None + +# --- PROCESS EXAMPLE (Targets 'query' field) --- +def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE): + """ + 针对从磁盘加载的 *Selected Top 20% UltraChat* 数据集中的单个样本进行 TTS 处理。 + Processes the example['query'] field. + """ + text_to_convert = example.get('query') + # Get identifiers for logging, if they exist in this dataset version + dialogue_id = example.get('dialogue_id', 'N/A') + turn_index = example.get('turn_index', 'N/A') # May not be present if not carried over + + if not text_to_convert or not isinstance(text_to_convert, str) or not text_to_convert.strip(): + print(f"Warning: Skipping example (ID: {dialogue_id}, Turn: {turn_index}) due to missing or empty 'query' field.") + return None + + # Call the unchanged text_to_audio function + audio_result = text_to_audio(text_to_convert, cosyvoice, common_voice_dataset, stream=False) + + if audio_result is not None: + audio_tensor = audio_result['audio_tensor'] # Still on GPU here + if audio_tensor is None or audio_tensor.numel() == 0: + print(f"Warning: TTS process returned empty tensor for query (ID: {dialogue_id}, Turn: {turn_index}): '{text_to_convert[:60]}...'") + return None + if audio_tensor.dim() == 1: audio_tensor = audio_tensor.unsqueeze(0) + elif audio_tensor.dim() > 2: + print(f"Warning: Generated audio tensor unexpected shape {audio_tensor.shape} (ID: {dialogue_id}, Turn: {turn_index}). 
Flattening.") + audio_tensor = audio_tensor.view(1, -1) # Flatten to [1, T] + if audio_tensor.numel() == 0: + print(f"Warning: Generated audio tensor became empty after reshape for query (ID: {dialogue_id}, Turn: {turn_index}): '{text_to_convert[:60]}...'") + return None + return { + 'audio_tensor': audio_tensor, # Return GPU tensor + 'sample_rate': audio_result['sample_rate'] + } + else: + return None # Errors logged within text_to_audio + +# ------------------------ +# 数据加载与模型初始化 (Model and Prompt Dataset Loading Unchanged) +# ------------------------ +print("Loading VoxPopuli (as Common Voice) dataset for prompts...") +try: + common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train', trust_remote_code=True) + common_voice = common_voice.filter(lambda x: x['audio'] is not None and x['audio']['array'] is not None and x['raw_text'] is not None and len(x['raw_text'].strip()) > 5 and len(x['audio']['array']) > SAMPLE_RATE * 0.5) + print(f"Loaded and filtered VoxPopuli '{COMMON_VOICE_LANGUAGE}' samples: {len(common_voice)}") + if len(common_voice) == 0: raise ValueError(f"VoxPopuli '{COMMON_VOICE_LANGUAGE}' loaded but no valid samples after filtering.") +except Exception as e: + print(f"Error loading or filtering VoxPopuli dataset: {e}") + sys.exit(1) + +print("Initializing CosyVoice2 model...") +try: + # CosyVoice initialization remains the same + cosyvoice_model_path = os.path.join(COSYVOICE_PATH, 'pretrained_models/CosyVoice2-0.5B') + if not os.path.isdir(cosyvoice_model_path): raise FileNotFoundError(f"CosyVoice pretrained model directory not found: {cosyvoice_model_path}") + cosyvoice = CosyVoice2( + cosyvoice_model_path, load_jit=True, load_trt=False, fp16=False + ) + print(f"CosyVoice model initialized. Target device: {effective_device}") +except Exception as e: + print(f"Error initializing CosyVoice2 model: {e}") + if isinstance(e, RuntimeError) and 'CUDA' in str(e): print("CUDA initialization error? Check GPU 1 status/memory.") + sys.exit(1) + +# --- !! MODIFIED Selected UltraChat Dataset Loading !! --- +print(f"\nLoading the target Selected UltraChat (Top 20%) dataset from disk: {INPUT_DATASET_PATH}") +if not os.path.exists(INPUT_DATASET_PATH): + print(f"Error: Input dataset directory not found at '{INPUT_DATASET_PATH}'.") + print("Please ensure the UltraChat Selection script ran successfully and produced the dataset at this location.") + sys.exit(1) + +try: + input_dataset = load_from_disk(INPUT_DATASET_PATH) + + print(f"Successfully loaded Selected UltraChat dataset with {len(input_dataset)} examples.") + if len(input_dataset) == 0: + print("Error: The loaded dataset is empty. 
Cannot proceed.") + sys.exit(1) + # Store original features to reconstruct the final dataset correctly + original_features = input_dataset.features + print(f"Original features: {original_features}") + # Check for 'query' column existence (essential for TTS) + if 'query' not in original_features: + print(f"Error: The loaded dataset from '{INPUT_DATASET_PATH}' does not contain the required 'query' column.") + sys.exit(1) + +except Exception as e: + print(f"Error loading dataset from '{INPUT_DATASET_PATH}': {e}") + sys.exit(1) +# --- End MODIFIED Dataset Loading --- + + +# --- Create output directories --- +os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True) +audio_output_dir = os.path.join(OUTPUT_DATASET_PATH, "audio_files") +os.makedirs(audio_output_dir, exist_ok=True) +print(f"Audio files will be saved in: {audio_output_dir}") +progress_file = os.path.join(OUTPUT_DATASET_PATH, "progress.txt") +print(f"Progress will be tracked in: {progress_file}") + + +# ------------------------ +# 主处理循环 (MODIFIED FOR SINGLE Selected UltraChat DATASET) +# ------------------------ +# --- !! MODIFIED: Update log message !! --- +print(f"\nStarting TTS processing for {len(input_dataset)} Selected UltraChat (Top 20%) samples...") + +start_index = 0 +# Read progress file to resume if necessary +if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + content = f.read().strip() + if content: start_index = int(content) + print(f"Resuming TTS processing from sample index {start_index}") + except Exception as e: + print(f"Error reading progress file '{progress_file}': {e}. Starting TTS from index 0.") + start_index = 0 + +# List to hold dictionaries for the final dataset +final_samples = [] + +# --- Main Loop --- +# --- !! MODIFIED: Update progress bar description !! 
--- +pbar = tqdm(range(start_index, len(input_dataset)), desc=f"TTS on Selected UltraChat 'query'", initial=start_index, total=len(input_dataset)) +for i in pbar: + sample = input_dataset[i] # Get sample dictionary (on CPU) + + # Define unique output WAV path using the index + output_wav_filename = f"query_{i}.wav" + output_wav_path = os.path.join(audio_output_dir, output_wav_filename) + + # --- Check if audio file already exists --- + if os.path.exists(output_wav_path): + sample_dict = dict(sample) + sample_dict["query_audio_filepath"] = output_wav_path # Add path field + final_samples.append(sample_dict) + with open(progress_file, "w") as f: f.write(str(i + 1)) + continue # Skip TTS + + # --- Perform TTS on the target device --- + result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE) + + if result is not None and 'audio_tensor' in result and result['audio_tensor'] is not None: + audio_tensor = result['audio_tensor'] # GPU tensor + sample_rate_val = result['sample_rate'] + try: + # Move tensor to CPU before saving + audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32) + if audio_tensor_save.dim() == 1: audio_tensor_save = audio_tensor_save.unsqueeze(0) + elif audio_tensor_save.dim() > 2: audio_tensor_save = audio_tensor_save.view(1, -1) + + torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val) + + # Create dict for the final dataset + sample_dict = dict(sample) + sample_dict["query_audio_filepath"] = output_wav_path # Add path field + final_samples.append(sample_dict) + + del audio_tensor # Delete GPU tensor + + except Exception as e: + # Log error with identifiers if available + dialogue_id = sample.get('dialogue_id', 'N/A') + turn_index = sample.get('turn_index', 'N/A') + print(f"Failed to save wav for sample {i} (ID: {dialogue_id}, Turn: {turn_index}) at {output_wav_path}: {e}") + if os.path.exists(output_wav_path): + try: os.remove(output_wav_path) + except OSError: pass + if torch.cuda.is_available(): torch.cuda.empty_cache() + else: + # Failure logged in process_example/text_to_audio + if torch.cuda.is_available(): torch.cuda.empty_cache() + + # --- Update progress file --- + with open(progress_file, "w") as f: f.write(str(i + 1)) + + # --- Optional: Periodic cache clearing --- + if i > 0 and i % 50 == 0: + if torch.cuda.is_available(): torch.cuda.empty_cache() + + +# --- Final cache clear after finishing the loop --- +if torch.cuda.is_available(): + print("Clearing final CUDA cache...") + torch.cuda.empty_cache() + +# ------------------------ +# 保存最终数据集 (MODIFIED FOR Selected UltraChat) +# ------------------------ +print("\nTTS processing loop finished.") +if final_samples: + # --- !! MODIFIED: Update log message !! --- + print(f"Successfully processed (or skipped existing) {len(final_samples)} Selected UltraChat (Top 20%) samples.") + + # --- Define features for the new dataset --- + new_features_dict = original_features.copy() + new_column_name = 'query_audio_filepath' # Name of the new column + if new_column_name in new_features_dict: + print(f"Warning: Feature '{new_column_name}' already exists in original features. 
Overwriting.") + new_features_dict[new_column_name] = Value('string') # Add the new column definition + try: + new_features = Features(new_features_dict) + print(f"Defined new features for saving: {new_features}") + + # --- Create the final Dataset object --- + print("Creating final Dataset object from processed samples...") + final_dataset_obj = Dataset.from_list(final_samples, features=new_features) + + # --- Define path to save the final dataset metadata object --- + final_dataset_save_path = os.path.join(OUTPUT_DATASET_PATH, "processed_dataset_with_audio") + # --- !! MODIFIED: Update log message !! --- + print(f"Saving final Selected UltraChat (Top 20%) dataset metadata (with audio paths) to: {final_dataset_save_path}...") + + # Ensure the target directory exists and is empty before saving + if os.path.exists(final_dataset_save_path): + print(f"Removing existing directory before saving: {final_dataset_save_path}") + shutil.rmtree(final_dataset_save_path) + + final_dataset_obj.save_to_disk(final_dataset_save_path) + print(f"Final dataset object saved successfully.") + + except Exception as e: + print(f"\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print(f"Error during final dataset creation or saving: {e}") + print(f"Audio files might be saved in '{audio_output_dir}', but the final dataset object could not be created/saved.") + print(f"Check the features and the content of 'final_samples'.") + print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + +else: + print("Processing finished, but no samples were successfully processed or had existing audio files.") + print(f"Check logs for TTS errors. Audio files directory: '{audio_output_dir}'.") + + +print("\n" + "="*60) +# --- !! MODIFIED: Update final log messages !! --- +print(f"Script finished for Selected UltraChat (Top 20%) dataset.") +print(f"Generated audio files are located in: '{audio_output_dir}'") +print(f"The final dataset (metadata + audio file paths) is saved at: '{os.path.join(OUTPUT_DATASET_PATH, 'processed_dataset_with_audio')}' (if saving was successful)") +print("="*60) \ No newline at end of file diff --git a/r1-a/prompt_only_examine.py b/r1-a/prompt_only_examine.py new file mode 100644 index 0000000000000000000000000000000000000000..e511dceb0ab1dc92a9fa087408758afa8b5c9dac --- /dev/null +++ b/r1-a/prompt_only_examine.py @@ -0,0 +1,48 @@ +from datasets import load_from_disk +import torchaudio +import os + +# IMPORTANT: When you load and use this dataset, your CWD should have the same +# relationship to the audio files as it did when the dataset was created, +# OR you need to manually resolve the paths. 
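+#
+# Illustrative sketch: if you know the base directory the relative paths were
+# written against (the comments below assume it was
+# '/root/autodl-tmp/audio-r1/r1-a/dataset/'), joining against it explicitly is
+# more robust than trusting the CWD:
+#
+#   base_dir = '/root/autodl-tmp/audio-r1/r1-a/dataset'
+#   abs_path = os.path.normpath(os.path.join(base_dir, relative_path_from_dataset))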
+
+# Path where the HF dataset was saved
+dataset_path = '/root/autodl-tmp/audio-r1/r1-a/dataset/prompt_only_fully_merged_with_audio/final_hf_dataset_relative_paths_v4'
+ds = load_from_disk(dataset_path)
+# breakpoint()  # uncomment to inspect the dataset structure interactively
+
+# Get a relative path string from the dataset
+relative_path_from_dataset = ds[0]['question_audio_relative_path']  # or your field name
+print(f"Relative path from dataset: {relative_path_from_dataset}")
+
+# To load the audio, the relative path must resolve correctly from your *current* CWD.
+# Option 1: if your CWD is the intended base, os.path.abspath() resolves against it.
+# Option 2: if you moved the dataset and audio files together while keeping their
+# relative structure, join the relative path onto the new root of that structure
+# instead of trusting the CWD (as sketched in the comments above).
+full_path_to_audio = os.path.abspath(relative_path_from_dataset)
+
+print(f"Attempting to load from (resolved path): {full_path_to_audio}")
+
+if os.path.exists(full_path_to_audio):
+    waveform, sample_rate = torchaudio.load(full_path_to_audio)
+    print(f"Loaded audio: waveform shape {waveform.shape}, sample rate {sample_rate}")
+else:
+    print(f"Audio file NOT FOUND at resolved path: {full_path_to_audio}")
+    print("Ensure your Current Working Directory is set correctly so the relative path can be resolved.")
+    print("Alternatively, construct the absolute path manually if you know where the audio files live relative to a fixed base.")
\ No newline at end of file
diff --git a/r1-a/train.py b/r1-a/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b71a413932c5ea5d72c025c6f73d34fea76c9d
--- /dev/null
+++ b/r1-a/train.py
@@ -0,0 +1,145 @@
+import torch
+import multiprocessing
+import time
+import os
+import sys
+
+def occupy_gpu(device_id, memory_fraction=0.90, compute_size=8192):
+    """
+    Target function for a process to occupy a specific GPU.
+    Args:
+        device_id (int): The ID of the GPU to occupy (e.g., 1 for cuda:1).
+        memory_fraction (float): Fraction of free memory to try and allocate (0.0 to 1.0).
+        compute_size (int): Dimension of square matrices for matmul compute load.
+                            Larger values increase compute intensity but also use some memory.
+    """
+    try:
+        # Ensure this process targets the correct GPU
+        torch.cuda.set_device(device_id)
+        device = f'cuda:{device_id}'
+        process_id = os.getpid()
+        print(f"[PID {process_id}] Targeting {device}...")
+
+        # --- 1. Allocate Memory ---
+        allocated_tensor = None
+        try:
+            # Get free memory and total memory
+            free_mem, total_mem = torch.cuda.mem_get_info(device_id)
+            target_alloc_bytes = int(free_mem * memory_fraction)
+            print(f"[PID {process_id}] {device}: Total Mem={total_mem/1024**3:.2f} GB, Free Mem={free_mem/1024**3:.2f} GB")
+            print(f"[PID {process_id}] {device}: Attempting to allocate ~{target_alloc_bytes/1024**3:.2f} GB ({memory_fraction*100:.0f}% of free)...")
+
+            # Calculate tensor size (using float32 = 4 bytes per element)
+            elements_needed = target_alloc_bytes // 4
+            # Create a 1D tensor first, as it's simpler to calculate size
+            allocated_tensor = torch.empty(elements_needed, dtype=torch.float32, device=device)
+            # Fill it with some data to ensure allocation happens (sometimes lazy allocation occurs)
+            allocated_tensor.fill_(1.0)
+            torch.cuda.synchronize(device_id) # Wait for allocation to complete
+
+            # Verify allocated memory (this is approximate as PyTorch reserves some overhead)
+            allocated_bytes = allocated_tensor.nelement() * allocated_tensor.element_size()
+            print(f"[PID {process_id}] {device}: Successfully allocated tensor using ~{allocated_bytes/1024**3:.2f} GB.")
+            # Keep the tensor alive by referencing it
+
+        except RuntimeError as e:
+            print(f"[PID {process_id}] {device}: ERROR allocating memory - {e}. Memory usage might be lower.")
+            print(f"[PID {process_id}] {device}: Check if {memory_fraction*100:.0f}% is too high or other processes are using memory.")
+            # Continue to compute loop even if memory allocation failed partially or fully
+
+        # --- 2. Run Compute Load ---
+        print(f"[PID {process_id}] {device}: Starting compute loop (matmul {compute_size}x{compute_size})...")
+        # Create tensors for computation
+        try:
+            a = torch.randn(compute_size, compute_size, dtype=torch.float32, device=device)
+            b = torch.randn(compute_size, compute_size, dtype=torch.float32, device=device)
+        except RuntimeError as e:
+            print(f"[PID {process_id}] {device}: ERROR creating compute tensors ({compute_size}x{compute_size}) - {e}.")
+            print(f"[PID {process_id}] {device}: GPU might not have enough remaining memory for this compute size. Try reducing 'compute_size'. Exiting process.")
+            return # Exit this process if we can't even create compute tensors
+
+        # Infinite compute loop
+        while True:
+            # Perform a compute-intensive operation
+            c = torch.matmul(a, b)
+            # Optional: add more operations if matmul alone isn't maxing out utilization
+            # a = a * 1.0001 # Avoid values growing too large/small quickly
+            # b = b + 0.0001
+            # torch.cuda.synchronize(device_id) # Usually not needed in a tight loop like this
+
+            # We don't need to do anything with 'c', the goal is just the computation.
+            # No sleep here, we want maximum utilization.
+
+    except Exception as e:
+        print(f"[PID {process_id}] {device}: UNEXPECTED ERROR - {e}")
+        # Log any other errors that might occur
+
+if __name__ == "__main__":
+    # --- Configuration ---
+    TARGET_GPU_IDS = [0,1]        # <<< Your target GPU IDs here (here cuda:0 and cuda:1)
+    MEMORY_FRACTION_TO_USE = 0.85 # <<< Try to use 85% of *free* memory. Adjust if needed (0.8 to 0.95 is typical)
+    COMPUTE_MATRIX_DIM = 8192     # <<< Dimension for matmul (e.g., 8192, 10240, 12288).
+                                  #     Larger = more compute intensive bursts, but uses more temp memory.
+ # Adjust based on GPU capability and remaining memory after allocation. + # --- End Configuration --- + + # Check CUDA availability and device count + if not torch.cuda.is_available(): + print("Error: CUDA is not available. Please check your PyTorch installation and CUDA drivers.") + sys.exit(1) + + num_gpus = torch.cuda.device_count() + print(f"Found {num_gpus} CUDA devices.") + + valid_target_gpus = [] + for gpu_id in TARGET_GPU_IDS: + if gpu_id < 0 or gpu_id >= num_gpus: + print(f"Warning: GPU ID {gpu_id} is invalid (must be between 0 and {num_gpus-1}). Skipping.") + else: + valid_target_gpus.append(gpu_id) + + if not valid_target_gpus: + print("Error: No valid target GPUs specified or available. Exiting.") + sys.exit(1) + + print(f"Attempting to occupy GPUs: {valid_target_gpus}") + print(f"Memory target: {MEMORY_FRACTION_TO_USE*100:.0f}% of free memory per GPU.") + print(f"Compute load: Matrix multiplication of size {COMPUTE_MATRIX_DIM}x{COMPUTE_MATRIX_DIM}.") + print("-" * 30) + + # Set multiprocessing start method (important for CUDA in some environments) + try: + multiprocessing.set_start_method('spawn', force=True) + except RuntimeError: + print("Note: Could not set multiprocessing start method to 'spawn'. Using default.") + pass + + processes = [] + for gpu_id in valid_target_gpus: + p = multiprocessing.Process(target=occupy_gpu, args=(gpu_id, MEMORY_FRACTION_TO_USE, COMPUTE_MATRIX_DIM)) + processes.append(p) + p.start() + + print("\nProcesses started. Monitor GPU usage with 'nvidia-smi'.") + print("Press Ctrl+C to stop the script and terminate processes.") + + try: + # Keep the main script alive while child processes run + for p in processes: + p.join() # Wait for processes to finish (they won't unless error or terminated) + except KeyboardInterrupt: + print("\nCtrl+C detected. Terminating GPU occupation processes...") + for p in processes: + if p.is_alive(): + p.terminate() # Send SIGTERM + p.join(timeout=5) # Wait max 5 seconds for graceful exit + if p.is_alive(): + print(f"Process {p.pid} did not terminate gracefully, killing.") + p.kill() # Send SIGKILL if necessary + p.join() # Wait for kill + print("All processes terminated.") + except Exception as main_e: + print(f"An error occurred in the main process: {main_e}") + # Optionally try to clean up child processes here too + for p in processes: + if p.is_alive(): p.terminate() \ No newline at end of file
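
For reference, a minimal sketch of how the artifacts these scripts produce could be consumed downstream. It assumes the sciq.py run above completed for the 'train' split; the path layout and the 'audio_filepath' column follow from that script's main loop, but the snippet itself is illustrative and not part of the diff.

import torchaudio
from datasets import load_from_disk

# Each split directory holds a "final_dataset" containing the original SciQ
# fields plus the "audio_filepath" column added by the main loop (assumed path).
ds = load_from_disk('./sciq_with_audio/train/final_dataset')
example = ds[0]
waveform, sample_rate = torchaudio.load(example['audio_filepath'])
print(example['question'], waveform.shape, sample_rate)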