# grsdfdf / r1-a / dataset / pkusafe_tts.py
# Uploaded by 1f: "Add files using upload-large-folder tool" (commit 19891ba, verified)
import os
import random
import torch
import torchaudio
# Import load_from_disk to load the dataset saved by your first script
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
import sys
from tqdm import tqdm
import time # Import time for potential delays between retries
sys.path.append('/root/autodl-tmp/CosyVoice') # Make sure this path is correct for your environment
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
# ------------------------
# Configuration
# ------------------------
# Language of the prompt-speech corpus (facebook/voxpopuli).
COMMON_VOICE_LANGUAGE = "en"

# Input: the pre-filtered dataset written by the first (filtering) script.
# IMPORTANT: this must match that script's OUTPUT_DATASET_DIR.
FILTERED_DATASET_PATH = "pku_saferlhf_filtered_unsafe_diverse_hf"

# Output: where this TTS script writes its wav files and the dataset with audio paths.
OUTPUT_DATASET_PATH = './pku_saferlhf_filtered_with_audio'

# Sample rate (Hz) that speech prompts are resampled to before being fed to CosyVoice2.
SAMPLE_RATE = 16000

# Retry policy for TTS synthesis.
MAX_TTS_RETRIES = 3      # maximum synthesis attempts per query
RETRY_DELAY_SECONDS = 2  # pause (seconds) between consecutive attempts
# ------------------------
# Helper functions (same as the previous version, including the TTS retry logic)
# ------------------------
def get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE, max_attempts=100):
    """Draw a random utterance and its transcript to use as a zero-shot TTS prompt.

    The prompt corpus is VoxPopuli (used in place of the original Common Voice).
    Each row is read as ``row['audio']`` (a dict with ``'array'`` and
    ``'sampling_rate'``) plus ``row['raw_text']``.

    Args:
        common_voice_dataset: Indexable dataset of prompt utterances.
        sample_rate: Rate (Hz) the prompt waveform is resampled to.
        max_attempts: How many random rows to try before giving up when rows
            have empty audio or a blank transcript.

    Returns:
        ``(waveform, text)``: a float32 tensor with a leading dimension of 1
        (1-D/2-D input becomes shape ``(1, num_samples)``) at ``sample_rate``,
        and the row's ``raw_text``.

    Raises:
        ValueError: If the dataset is empty, or no usable row is found within
            ``max_attempts`` draws.
    """
    num_rows = len(common_voice_dataset)
    if num_rows == 0:
        raise ValueError("Prompt dataset is empty; cannot draw a speech prompt.")

    # Bounded loop instead of recursion: a run of bad rows can no longer
    # exhaust the interpreter's recursion limit.
    for _ in range(max_attempts):
        sample = common_voice_dataset[random.randint(0, num_rows - 1)]
        audio = sample['audio']
        text = sample['raw_text']
        waveform = torch.tensor(audio['array'], dtype=torch.float32)

        # Reject unusable rows before doing any resampling work.
        if waveform.numel() == 0 or not text or not text.strip():
            print("Warning: Got an empty prompt, trying again...")
            continue

        # Downmix regardless of sample rate so the prompt is always mono.
        # NOTE(review): assumes multi-channel arrays are (channels, samples) — confirm.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)

        sr = audio['sampling_rate']
        if sr != sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
            waveform = resampler(waveform)

        return waveform.unsqueeze(0), text

    raise ValueError(f"No non-empty speech prompt found after {max_attempts} random draws.")
def text_to_audio(query_text, cosyvoice, common_voice_dataset, stream=False, max_retries=MAX_TTS_RETRIES):
    """Synthesize *query_text* with CosyVoice2 zero-shot inference and a random speech prompt.

    Every attempt draws a fresh prompt via ``get_random_prompt``, so a retry also
    changes the reference utterance. Failed attempts are followed by a pause of
    ``RETRY_DELAY_SECONDS`` before the next one.

    Args:
        query_text: Text to synthesize.
        cosyvoice: Model object providing ``inference_zero_shot`` and ``sample_rate``.
        common_voice_dataset: Prompt corpus forwarded to ``get_random_prompt``.
        stream: Forwarded to ``inference_zero_shot``.
        max_retries: Maximum number of attempts; a value <= 0 makes no attempt.

    Returns:
        ``{'audio_tensor': ..., 'sample_rate': cosyvoice.sample_rate}``, where the
        tensor is every yielded ``'tts_speech'`` chunk concatenated along the last
        dimension, or ``None`` if all attempts failed.
    """
    last_exception = None
    for attempt in range(max_retries):
        # Reset on every attempt. If get_random_prompt raises, the error report
        # below would otherwise hit an unbound name (first attempt, which used to
        # crash the whole run) or print the previous attempt's prompt.
        prompt_text = ""
        try:
            prompt_speech, prompt_text = get_random_prompt(common_voice_dataset, sample_rate=SAMPLE_RATE)
            inference_generator = cosyvoice.inference_zero_shot(
                query_text,
                prompt_text,
                prompt_speech,
                stream=stream,
                text_frontend=False,
            )
            all_speech = []
            for i, chunk in enumerate(inference_generator):
                if 'tts_speech' in chunk and chunk['tts_speech'] is not None:
                    all_speech.append(chunk['tts_speech'])
                else:
                    print(f"Warning: Chunk {i} missing 'tts_speech' for text '{query_text[:60]}...'")
            if not all_speech:
                raise ValueError("TTS inference finished but produced no audio chunks.")
            return {
                'audio_tensor': torch.cat(all_speech, dim=-1),
                'sample_rate': cosyvoice.sample_rate,
            }
        except Exception as e:
            last_exception = e
            print(f"Error converting text to audio on attempt {attempt + 1}/{max_retries}: {e}")
            print(f"Text: '{query_text[:100]}...'")
            print(f"Prompt Text: '{prompt_text[:100]}...'")
            if attempt < max_retries - 1:
                print(f"Retrying with a different prompt in {RETRY_DELAY_SECONDS}s...")
                time.sleep(RETRY_DELAY_SECONDS)
            else:
                print(f"All {max_retries} TTS attempts failed.")
    print(f"Failed to generate audio for text after {max_retries} attempts: '{query_text[:60]}...'")
    print(f"Last error: {last_exception}")
    return None
def process_example(example, cosyvoice, common_voice_dataset, sample_rate=SAMPLE_RATE):
    """Run TTS on the ``'prompt'`` field of one example from the filtered PKU-SafeRLHF dataset.

    Args:
        example: Dict-like dataset row; only ``example['prompt']`` is read.
        cosyvoice: TTS model forwarded to ``text_to_audio``.
        common_voice_dataset: Prompt corpus forwarded to ``text_to_audio``.
        sample_rate: Not used by this function (kept for interface
            compatibility); the returned rate comes from ``text_to_audio``.

    Returns:
        ``{'audio_tensor': tensor, 'sample_rate': int}``, with 1-D audio
        promoted to ``(1, N)`` and >2-D audio flattened to ``(1, N)``.
        Returns ``None`` when the prompt is missing, non-string or blank, when
        synthesis fails, or when the generated audio is empty.
    """
    query = example.get('prompt')
    # Only non-blank strings are synthesized; anything else is skipped.
    if not isinstance(query, str) or not query.strip():
        print(f"Warning: Skipping example due to missing or empty 'prompt' field: {example.keys()}")  # Log keys if prompt is missing
        return None

    # text_to_audio handles retries internally and returns None on total failure.
    audio_result = text_to_audio(query, cosyvoice, common_voice_dataset, stream=False)
    if audio_result is None:
        return None

    audio_tensor = audio_result['audio_tensor']
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    elif audio_tensor.dim() > 2:
        print(f"Warning: Generated audio tensor has unexpected shape {audio_tensor.shape}. Attempting to flatten.")
        # reshape (unlike view) also succeeds on non-contiguous tensors.
        audio_tensor = audio_tensor.reshape(1, -1)

    if audio_tensor.numel() == 0:
        print(f"Warning: Generated audio tensor is empty for prompt: '{query[:60]}...'")
        return None

    return {
        'audio_tensor': audio_tensor,
        'sample_rate': audio_result['sample_rate'],
    }
# ------------------------
# Data loading and model initialization
# ------------------------
# Load the prompt-speech corpus. The variable keeps the historical name
# "common_voice" because VoxPopuli replaced Common Voice as the prompt source.
print("Loading VoxPopuli (as Common Voice) dataset for prompts...")
try:
    # VoxPopuli configs are named by language code. Loading COMMON_VOICE_LANGUAGE
    # (instead of a hard-coded "en") keeps the data and the log line consistent.
    common_voice = load_dataset("facebook/voxpopuli", COMMON_VOICE_LANGUAGE, split='train')
    print(f"Total VoxPopuli {COMMON_VOICE_LANGUAGE} samples: {len(common_voice)}")
    if len(common_voice) == 0:
        raise ValueError("VoxPopuli dataset loaded but contains no samples.")
    # Fail fast here rather than on the first TTS call: get_random_prompt reads these columns.
    missing_columns = [c for c in ('audio', 'raw_text') if c not in common_voice.column_names]
    if missing_columns:
        raise ValueError(f"VoxPopuli dataset is missing required columns: {missing_columns}")
except Exception as e:
    print(f"Error loading VoxPopuli dataset: {e}")
    sys.exit(1)
# Build the CosyVoice2 TTS model; nothing downstream can run without it.
print("Initializing CosyVoice2 model...")
_COSYVOICE_MODEL_DIR = '/root/autodl-tmp/CosyVoice/pretrained_models/CosyVoice2-0.5B'  # Verify this path is correct
_COSYVOICE_LOAD_OPTIONS = {'load_jit': True, 'load_trt': False, 'fp16': False}
try:
    cosyvoice = CosyVoice2(_COSYVOICE_MODEL_DIR, **_COSYVOICE_LOAD_OPTIONS)
except Exception as init_error:
    print(f"Error initializing CosyVoice2 model: {init_error}")
    sys.exit(1)
# Load the pre-filtered PKU-SafeRLHF data written by the first script and
# normalize it into `dataset_dict`: a mapping of split name -> Dataset, which
# is what the main loop iterates over.
print(f"Loading pre-filtered PKU-SafeRLHF dataset from disk: {FILTERED_DATASET_PATH}")
try:
    filtered_dataset = load_from_disk(FILTERED_DATASET_PATH)
    if isinstance(filtered_dataset, DatasetDict):
        # The first script saved several named splits. A DatasetDict already is
        # a split-name -> Dataset mapping, so use it as-is rather than wrapping
        # it as one fake "train" split (its len() would count splits, not rows).
        dataset_dict = dict(filtered_dataset)
        total_examples = sum(len(split) for split in dataset_dict.values())
        if total_examples == 0:
            raise ValueError(f"Dataset loaded from '{FILTERED_DATASET_PATH}' is empty or invalid.")
        print(f"Successfully loaded DatasetDict with splits {list(dataset_dict)} ({total_examples} examples in total).")
    else:
        if len(filtered_dataset) == 0:
            raise ValueError(f"Dataset loaded from '{FILTERED_DATASET_PATH}' is empty or invalid.")
        print(f"Successfully loaded dataset with {len(filtered_dataset)} examples.")
        # A single Dataset is treated as the "train" split, the split the
        # filter script processed. Change the key if that script used another name.
        dataset_dict = {"train": filtered_dataset}
except FileNotFoundError:
    print(f"Error: Pre-filtered dataset not found at '{FILTERED_DATASET_PATH}'.")
    print("Please ensure the first script ran successfully and saved the data to the correct location.")
    sys.exit(1)
except Exception as e:
    print(f"Error loading pre-filtered dataset from '{FILTERED_DATASET_PATH}': {e}")
    sys.exit(1)
# Make sure the root output directory exists (no error if it is already there).
os.makedirs(name=OUTPUT_DATASET_PATH, exist_ok=True)
# ------------------------
# Main processing loop
# ------------------------
# Final processed Dataset for each split, keyed by split name.
final_dataset_dict = {}


def _read_start_index(progress_file, split_name):
    """Return the sample index to resume from, or 0 when there is no usable progress file."""
    if not os.path.exists(progress_file):
        return 0
    try:
        with open(progress_file, "r") as f:
            content = f.read().strip()
    except Exception as e:
        print(f"Error reading progress file '{progress_file}': {e}. Starting from index 0.")
        return 0
    if not content:
        print(f"Progress file '{progress_file}' is empty, starting from index 0.")
        return 0
    try:
        start_index = int(content)
    except ValueError:
        print(f"Could not parse integer from progress file '{progress_file}'. Starting from index 0.")
        return 0
    print(f"Resuming split '{split_name}' from sample index {start_index}")
    return start_index


def _write_progress(progress_file, next_index):
    """Record that every sample before *next_index* has been attempted."""
    with open(progress_file, "w") as f:
        f.write(str(next_index))


def _row_with_audio(sample, wav_path):
    """Return a copy of a dataset row with the path of its generated wav attached."""
    row = dict(sample)
    row["audio_filepath"] = wav_path
    return row


# Iterate over the splits in dataset_dict (normally just 'train').
for split_name, split_dataset in dataset_dict.items():
    num_examples = len(split_dataset)
    print(f"Processing split: {split_name} with {num_examples} examples")
    split_output_dir = os.path.join(OUTPUT_DATASET_PATH, split_name)
    os.makedirs(split_output_dir, exist_ok=True)

    # Progress record for resuming an interrupted run: it stores the index of
    # the next sample that has not been attempted yet.
    progress_file = os.path.join(split_output_dir, "progress.txt")
    start_index = _read_start_index(progress_file, split_name)

    final_samples = []
    # Scan the whole split, not only [start_index, end). Wavs produced by
    # earlier runs must be collected too; otherwise a resumed run saves a final
    # dataset containing only the samples synthesized in the current session.
    pbar = tqdm(range(num_examples), desc=f"Processing {split_name}", total=num_examples)
    for i in pbar:
        output_wav_path = os.path.join(split_output_dir, f"{i}.wav")
        already_attempted = i < start_index

        if os.path.exists(output_wav_path):
            final_samples.append(_row_with_audio(split_dataset[i], output_wav_path))
            # Never move the progress marker backwards while re-collecting old results.
            if not already_attempted:
                _write_progress(progress_file, i + 1)
            continue
        if already_attempted:
            # A previous run tried this sample and produced no audio; don't retry it.
            continue

        sample = split_dataset[i]
        # Perform TTS on the 'prompt' field.
        result = process_example(sample, cosyvoice, common_voice, sample_rate=SAMPLE_RATE)
        if result is not None:
            try:
                audio_tensor_save = result['audio_tensor'].detach().cpu().to(torch.float32)
                if audio_tensor_save.dim() == 1:
                    audio_tensor_save = audio_tensor_save.unsqueeze(0)
                elif audio_tensor_save.dim() > 2:
                    audio_tensor_save = audio_tensor_save.reshape(1, -1)
                # Write to a temporary file, then rename it into place. A crash
                # mid-write therefore cannot leave a truncated "{i}.wav" that a
                # later resume would accept as finished.
                tmp_wav_path = os.path.join(split_output_dir, f"{i}.partial.wav")
                torchaudio.save(tmp_wav_path, audio_tensor_save, result['sample_rate'])
                os.replace(tmp_wav_path, output_wav_path)
                # Preserve all original fields from the filtered dataset and add the audio path.
                final_samples.append(_row_with_audio(sample, output_wav_path))
            except Exception as e:
                print(f"Failed to save wav for sample {i} at {output_wav_path}: {e}")
        else:
            # `or 'N/A'` also covers a present-but-None prompt, which would
            # otherwise raise TypeError when sliced.
            prompt_preview = str(sample.get('prompt') or 'N/A')[:60]
            print(f"Sample {i} processing failed after retries (Prompt: '{prompt_preview}...'), no audio generated.")

        # Update the progress file after each attempted sample.
        _write_progress(progress_file, i + 1)

    # Build a Hugging Face Dataset from the successful samples and save it.
    if final_samples:
        final_dataset_obj = Dataset.from_list(final_samples)
        final_dataset_save_path = os.path.join(split_output_dir, "final_dataset")
        try:
            print(f"Saving final dataset for split '{split_name}' to {final_dataset_save_path}...")
            final_dataset_obj.save_to_disk(final_dataset_save_path)
            print(f"Finished processing split: {split_name}. Saved {len(final_samples)} samples with audio paths.")
            final_dataset_dict[split_name] = final_dataset_obj
        except Exception as e:
            print(f"Error saving final dataset for split '{split_name}' to disk: {e}")
    else:
        print(f"Finished processing split: {split_name}. No samples were successfully processed or saved.")

print("="*30)
if final_dataset_dict:
    print(f"All specified splits processed. Final datasets saved in respective subdirectories within '{OUTPUT_DATASET_PATH}'.")
    print(f"Processed splits: {list(final_dataset_dict.keys())}")
else:
    print(f"Processing finished, but no final datasets were generated or saved in '{OUTPUT_DATASET_PATH}'. Check logs for errors.")
print("="*30)