| import os |
| import random |
| import torch |
| import torchaudio |
| |
| from datasets import load_dataset, Dataset, load_from_disk |
| import sys |
| from tqdm import tqdm |
| import time |
|
|
| sys.path.append('/root/autodl-tmp/CosyVoice') |
| from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 |
| from cosyvoice.utils.file_utils import load_wav |
|
|
| |
| |
| |
# Language code of the prompt-voice corpus (VoxPopuli config loaded below).
COMMON_VOICE_LANGUAGE = "en"

# Input: filtered PKU-SafeRLHF dataset saved to disk by the preceding script.
FILTERED_DATASET_PATH = "pku_saferlhf_filtered_unsafe_diverse_hf"

# Output root for per-split wav files and the final dataset with audio paths.
OUTPUT_DATASET_PATH = './pku_saferlhf_filtered_with_audio'
SAMPLE_RATE = 16000        # target sampling rate (Hz) for prompt audio
MAX_TTS_RETRIES = 3        # TTS attempts per text before giving up
RETRY_DELAY_SECONDS = 2    # pause between TTS retry attempts
|
|
| |
| |
| |
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE, max_attempts=10):
    """
    Randomly draw one (waveform, transcript) pair from the VoxPopuli dataset
    (replacing the original Common Voice) to use as a zero-shot TTS prompt.

    The waveform is resampled to `sample_rate` when necessary and returned
    with shape (1, num_samples). Samples with empty audio or an empty
    'raw_text' field are rejected and a new random index is drawn.

    Args:
        common_voice_dataset: HF dataset with 'audio' and 'raw_text' columns.
        sample_rate: target sampling rate (Hz) for the returned waveform.
        max_attempts: maximum random draws before giving up (the original
            implementation recursed without bound on empty samples, which
            could recurse forever / raise RecursionError on a degenerate
            dataset).

    Returns:
        (waveform, text): float32 tensor of shape (1, T) and the prompt text.

    Raises:
        ValueError: if no usable prompt is found within `max_attempts` draws.
    """
    for _ in range(max_attempts):
        idx = random.randint(0, len(common_voice_dataset) - 1)
        sample = common_voice_dataset.select([idx])[0]
        audio = sample['audio']
        waveform = torch.tensor(audio['array'], dtype=torch.float32)
        sr = audio['sampling_rate']
        if sr != sample_rate:
            # Downmix multi-channel audio to mono before resampling.
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
            waveform = resampler(waveform)
        if waveform.dim() == 1:
            # Normalize to channel-first (1, T) layout.
            waveform = waveform.unsqueeze(0)
        if waveform.numel() == 0 or not sample['raw_text']:
            # Unusable prompt: redraw instead of recursing.
            print("Warning: Got an empty prompt, trying again...")
            continue
        return waveform, sample['raw_text']
    raise ValueError(f"No non-empty prompt found after {max_attempts} random draws.")
|
|
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
    """
    Convert `query_text` to speech with the CosyVoice2 model, using a random
    (audio, transcript) pair as the zero-shot prompt.

    Retries up to `max_retries` times with a fresh random prompt on failure.

    Args:
        query_text: text to synthesize.
        cosyvoice: initialized CosyVoice2 model instance.
        common_voice_dataset: prompt corpus passed to get_random_prompt().
        stream: forwarded to inference_zero_shot.
        max_retries: number of attempts before giving up.

    Returns:
        dict with 'audio_tensor' (torch.Tensor, chunks concatenated along the
        time axis) and 'sample_rate' (int), or None if all attempts failed.
    """
    last_exception = None
    for attempt in range(max_retries):
        # Bug fix: initialize before the try so the except-block can always
        # print it. Previously, if get_random_prompt() itself raised, the
        # `prompt_text[:100]` log line hit UnboundLocalError and masked the
        # real error.
        prompt_text = ""
        try:
            prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)

            all_speech = []
            inference_generator = cosyvoice.inference_zero_shot(
                query_text,
                prompt_text,
                prompt_speech,
                stream=stream,
                text_frontend=False
            )
            for i, chunk in enumerate(inference_generator):
                if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
                    all_speech.append(chunk['tts_speech'])
                else:
                    print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")

            if not all_speech:
                # Treat a silent generator as a failure so the retry loop fires.
                raise ValueError("TTS inference finished but produced no audio chunks.")

            # Concatenate streamed chunks along the time (last) axis.
            combined_speech = torch.cat(all_speech, dim=-1)
            sample_rate_val = cosyvoice.sample_rate

            return {
                'audio_tensor': combined_speech,
                'sample_rate': sample_rate_val
            }
        except Exception as e:
            last_exception = e
            print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
            print(f"Text: '{query_text[:100]}...'")
            print(f"Prompt Text: '{prompt_text[:100]}...'")
            if attempt < max_retries - 1:
                print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
                time.sleep(RETRY_DELAY_SECONDS)
            else:
                print(f"All {max_retries} TTS attempts failed.")

    print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
    print(f"Last error: {last_exception}")
    return None
|
|
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """
    Run TTS over one sample of the filtered PKU-SafeRLHF dataset.

    The text to synthesize is read from example['prompt'] (was 'query' /
    'question' in earlier dataset versions).

    Returns:
        dict {'audio_tensor': (1, T) tensor, 'sample_rate': int} on success,
        or None when the text is missing/empty or synthesis failed.
    """
    query = example.get('prompt')

    # Guard: skip samples without usable prompt text.
    if not query or not isinstance(query, str) or query.strip() == "":
        print(f"Warning: Skipping example due to missing or empty 'prompt' field: {example.keys()}")
        return None

    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is None:
        return None

    audio_tensor = audio_result['audio_tensor']
    if audio_tensor.dim() == 1:
        # Promote 1-D output to channel-first (1, T) layout.
        audio_tensor = audio_tensor.unsqueeze(0)
    elif audio_tensor.dim() > 2:
        print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
        audio_tensor = audio_tensor.view(1, -1)

    # Guard: reject empty audio even if TTS reported success.
    if audio_tensor.numel() == 0:
        print(f"Warning: Generated audio tensor is empty for prompt: '{query[:60]}...'")
        return None

    return {
        'audio_tensor': audio_tensor,
        'sample_rate': audio_result['sample_rate']
    }
|
|
| |
| |
| |
# --- Load the VoxPopuli corpus that supplies zero-shot voice prompts ---
print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
try:
    # Consistency fix: use the COMMON_VOICE_LANGUAGE constant instead of the
    # hard-coded "en" (the log line below already reports the constant).
    common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train')
    print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
    if len(common_voice) == 0:
        raise ValueError("VoxPopuli dataset loaded but contains no samples.")
except Exception as e:
    # The prompt corpus is mandatory; abort the whole script without it.
    print(f"Error loading VoxPopuli dataset: {e}")
    sys.exit(1)
|
|
|
|
# --- Initialize the CosyVoice2 TTS model from a local checkpoint ---
print("Initializing CosyVoice2 model...")
try:
    cosyvoice = CosyVoice2(
        '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B',
        load_jit=True,    # NOTE(review): presumably loads JIT-compiled submodules — confirm against CosyVoice docs
        load_trt=False,   # TensorRT path disabled
        fp16=False        # full-precision inference
    )
except Exception as e:
    # The model is mandatory; abort the script if it cannot be constructed.
    print(f"Error initializing CosyVoice2 model: {e}")
    sys.exit(1)
|
|
# --- Load the pre-filtered PKU-SafeRLHF dataset produced by the first script ---
print(f"Loading pre-filtered PKU-SafeRLHF dataset from disk: {FILTERED_DATASET_PATH}")
try:
    filtered_dataset = load_from_disk(FILTERED_DATASET_PATH)
    # An empty Dataset is falsy, so this also rejects zero-row data.
    if not filtered_dataset:
        raise ValueError(f"Dataset loaded from '{FILTERED_DATASET_PATH}' is empty or invalid.")
    print(f"Successfully loaded dataset with {len(filtered_dataset)} examples.")

    # Wrap in a single-split dict so the processing loop below can iterate
    # it like a DatasetDict.
    dataset_dict = {"train": filtered_dataset}

except FileNotFoundError:
    print(f"Error: Pre-filtered dataset not found at '{FILTERED_DATASET_PATH}'.")
    print("Please ensure the first script ran successfully and saved the data to the correct location.")
    sys.exit(1)
except Exception as e:
    print(f"Error loading pre-filtered dataset from '{FILTERED_DATASET_PATH}': {e}")
    sys.exit(1)
|
|
| |
# Ensure the output root exists before per-split subdirectories are created.
os.makedirs(OUTPUT_DATASET_PATH, exist_ok=True)


# Collects the final Dataset object per split for the summary at the end.
final_dataset_dict = {}
|
|
| |
# --- Main processing loop: synthesize audio for every sample in every split ---
for split_name, split_dataset in dataset_dict.items():
    print(f"Processing split: {split_name} with {len(split_dataset)} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Resume support: progress.txt holds the index of the next sample to
    # process for this split.
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                content = f.read().strip()
                if content:
                    start_index = int(content)
                    print(f"Resuming split '{split_name}' from sample index {start_index}")
                else:
                    print(f"Progress file '{progress_file}' is empty, starting from index 0.")
                    start_index = 0
        except ValueError:
            print(f"Could not parse integer from progress file '{progress_file}'. Starting from index 0.")
            start_index = 0
        except Exception as e:
            print(f"Error reading progress file '{progress_file}': {e}. Starting from index 0.")
            start_index = 0

    final_samples = []

    # BUG FIX: when resuming (start_index > 0), the loop below starts at
    # start_index, so samples already synthesized in a previous run were
    # never re-added to final_samples and the saved dataset silently dropped
    # them. Recover those earlier samples from their existing wav files.
    for j in range(start_index):
        prev_wav_path = os.path.join(split_output_dir, f"{j}.wav")
        if os.path.exists(prev_wav_path):
            prev_sample = split_dataset[j]
            prev_dict = {k: prev_sample[k] for k in prev_sample.keys()}
            prev_dict["audio_filepath"] = prev_wav_path
            final_samples.append(prev_dict)

    pbar = tqdm(range(start_index, len(split_dataset)), desc=f"Processing {split_name}", initial=start_index, total=len(split_dataset))
    for i in pbar:
        sample = split_dataset[i]
        output_wav_path = os.path.join(split_output_dir, f"{i}.wav")

        # Skip re-synthesis if the wav already exists; just record its path.
        if os.path.exists(output_wav_path):
            sample_dict = {k: sample[k] for k in sample.keys()}
            sample_dict["audio_filepath"] = output_wav_path
            final_samples.append(sample_dict)
            with open(progress_file, "w") as f:
                f.write(str(i + 1))
            continue

        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)

        if result is not None:
            audio_tensor = result['audio_tensor']
            sample_rate_val = result['sample_rate']

            try:
                # torchaudio.save expects a CPU float tensor shaped (C, T).
                audio_tensor_save = audio_tensor.detach().cpu().to(torch.float32)
                if audio_tensor_save.dim() == 1:
                    audio_tensor_save = audio_tensor_save.unsqueeze(0)
                elif audio_tensor_save.dim() > 2:
                    audio_tensor_save = audio_tensor_save.view(1, -1)

                torchaudio.save(output_wav_path, audio_tensor_save, sample_rate_val)

                # Record the original fields plus the path to the new wav.
                sample_dict = {k: sample[k] for k in sample.keys()}
                sample_dict["audio_filepath"] = output_wav_path
                final_samples.append(sample_dict)

            except Exception as e:
                # Best-effort: a failed save drops this sample but keeps going.
                print(f"Failed to save wav for sample {i} at {output_wav_path}: {e}")
        else:
            print(f"Sample {i} processing failed after retries (Prompt: '{sample.get('prompt', 'N/A')[:60]}...'), no audio generated.")

        # Persist progress after every sample so a crash resumes at i + 1.
        with open(progress_file, "w") as f:
            f.write(str(i + 1))

    # Persist the accumulated samples for this split as a HF Dataset.
    if final_samples:
        final_dataset_obj = Dataset.from_list(final_samples)
        final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
        try:
            print(f"Saving final dataset for split '{split_name}' to {final_dataset_save_path}...")
            final_dataset_obj.save_to_disk(final_dataset_save_path)
            print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with audio paths.")
            final_dataset_dict[split_name] = final_dataset_obj
        except Exception as e:
            print(f"Error saving final dataset for split '{split_name}' to disk: {e}")
    else:
        print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.")
|
|
|
|
# Final summary: report which splits produced a saved dataset.
separator = "=" * 30
print(separator)
if final_dataset_dict:
    print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.")
    print(f"Processed splits: {list(final_dataset_dict.keys())}")
else:
    print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.")
print(separator)