import os
import json
import base64
import uuid
import time
import re
import random
import concurrent.futures
from tqdm import tqdm
import threading
import traceback # For detailed error logging
import requests # Use requests library for HTTP calls
# numpy is used by the JSON-lines fallback serialization in save_checkpoint
import numpy as np
from datasets import load_from_disk, Dataset
from dotenv import load_dotenv
# --- Configuration ---
load_dotenv()
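# The .env file alongside this script is expected to provide the API key, e.g.:
#   AIGCBEST_API_KEY=sk-...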
# 1. API Client Setup
GPT4O_MODEL_NAME = "freeze_omni" # How it's identified in your dataset's model columns
API_MODEL_NAME = "gpt-4o-mini-audio-preview" # Actual model name for the API call
API_ENDPOINT = "https://api2.aigcbest.top/v1/chat/completions"
try:
    # A single key for this service, read from the environment (see the .env note above)
    API_TOKEN = os.getenv("AIGCBEST_API_KEY")
if not API_TOKEN:
raise ValueError("AIGCBEST_API_KEY environment variable not set.")
print("AIGCBEST API Key loaded.")
except Exception as e:
print(f"FATAL: Error getting API Key: {e}")
exit(1)
# 2. Dataset Paths
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o_mini"
# 3. Output Audio Configuration
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_mini"
OUTPUT_AUDIO_FORMAT = "wav" # API will be requested to return wav
AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse']
# 4. API Call Settings
API_TIMEOUT = 240
API_RETRY_DELAY = 5
API_MAX_RETRIES = 3 # Max attempts *for the task*
MAX_WORKERS = 8 # Adjust based on API rate limits and system resources
# 5. Checkpoint Saving Configuration
CHECKPOINT_INTERVAL = 500 # Save every 500 completed tasks
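# Checkpoints overwrite OUTPUT_DATASET_DIR in place; the resume logic below
# loads from that same directory, so an interrupted run picks up prior results.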
# --- Helper Functions ---
def encode_audio_base64(audio_path):
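    """Read an audio file and return its contents base64-encoded, or None on failure."""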
if not audio_path or not os.path.exists(audio_path):
print(f"Warning: Input audio file not found or path is empty: {audio_path}")
return None
try:
with open(audio_path, "rb") as audio_file:
return base64.b64encode(audio_file.read()).decode("utf-8")
except Exception as e:
print(f"Error encoding audio file {audio_path}: {e}")
return None
def parse_ultra_history(history_str):
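    """
    Split an "[USER] ... [ASSISTANT] ..." transcript into chat messages, e.g.:

        >>> parse_ultra_history("[USER] Hi there [ASSISTANT] Hello!")
        [{'role': 'user', 'content': 'Hi there'}, {'role': 'assistant', 'content': 'Hello!'}]
    """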
messages = []
pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
matches = pattern.findall(history_str)
if not matches:
return []
for role_tag, content in matches:
role = role_tag.lower()
cleaned_content = content.strip()
if cleaned_content:
messages.append({"role": role, "content": cleaned_content})
return messages
# --- API Call Worker Function for GPT-4o ---
def call_gpt4o_api_worker(task_info):
"""
Worker function to call the custom GPT-4o API for a single task.
"""
row_idx = task_info["row_idx"]
slot_idx = task_info["slot_idx"]
history_messages = task_info["history_messages"]
prompt_text = task_info["prompt_text"]
question_text = task_info["question_text"]
question_audio_path = task_info["question_audio_path"]
output_audio_filepath = task_info["output_audio_filepath"]
retries = 0
headers = {
'Accept': 'application/json',
'Authorization': f'Bearer {API_TOKEN}', # Use the single loaded token
'Content-Type': 'application/json'
}
selected_voice = random.choice(AVAILABLE_VOICES)
# print(f" [Thread-{threading.get_ident()}] Processing Row {row_idx}, Slot {slot_idx} (GPT4o Voice: {selected_voice})") # Optional log
while retries < API_MAX_RETRIES:
try:
# 1. Prepare Input Audio
base64_audio_data = encode_audio_base64(question_audio_path)
if not base64_audio_data:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
# 2. Construct User Message Content
            combined_text = prompt_text
user_content_list = [
{"type": "text", "text": combined_text},
{"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
]
messages = history_messages + [{"role": "user", "content": user_content_list}]
            # 3. Construct Payload
payload = {
"model": API_MODEL_NAME,
"modalities": ["text", "audio"],
"audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
"messages": messages
}
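            # "modalities" asks the API to return both text and audio; the "audio"
            # object selects the TTS voice and container format (per the OpenAI
            # audio-preview chat-completions API).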
            # 4. Make API Call
response = requests.post(
API_ENDPOINT,
headers=headers,
json=payload,
timeout=API_TIMEOUT
)
            # 5. Process Response
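            # Assumed success-response shape (based on the fields parsed below):
            # {"choices": [{"message": {
            #     "content": "..." or [{"type": "text", "text": "..."}],
            #     "audio": {"data": "<base64 audio>", "transcript": "..."}}}]}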
if response.status_code == 200:
try:
response_data = response.json()
                        # Parse defensively; fields may be missing or malformed
choices = response_data.get('choices')
if not choices or not isinstance(choices, list) or len(choices) == 0:
raise ValueError("Invalid or empty 'choices' field in response.")
message_content = choices[0].get('message', {})
if not message_content:
raise ValueError("Missing 'message' field in the first choice.")
audio_info = message_content.get('audio', {})
if not isinstance(audio_info, dict): audio_info = {} # Handle case where audio might be null or not a dict
audio_base64_string = audio_info.get('data', '')
# Try getting text from 'content' if 'transcript' is missing/empty in 'audio'
collected_text = audio_info.get('transcript', '').strip()
if not collected_text:
text_content_list = message_content.get('content', [])
if isinstance(text_content_list, list):
for item in text_content_list:
if isinstance(item, dict) and item.get("type") == "text":
collected_text = item.get("text", "").strip()
break # Take the first text part found
# Still no text? Try the top-level message content directly if it's a string
elif isinstance(message_content.get('content'), str):
collected_text = message_content['content'].strip()
if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
saved_audio_path = None
if audio_base64_string:
try:
wav_bytes = base64.b64decode(audio_base64_string)
if len(wav_bytes) == 0:
print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
else:
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
with open(output_audio_filepath, "wb") as f:
f.write(wav_bytes)
saved_audio_path = output_audio_filepath
# print(f" Audio saved to: {output_audio_filepath}") # Less verbose log
except base64.binascii.Error as b64_err:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
except Exception as e:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
# TASK SUCCEEDED (even if audio saving failed, text might be valid)
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
print(f" Response Text (start): {response.text[:500]}...")
retries += 1
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
time.sleep(API_RETRY_DELAY)
continue
except Exception as e: # Catch-all for unexpected errors during processing
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
print(traceback.format_exc())
retries += 1
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
time.sleep(API_RETRY_DELAY)
continue
else: # Handle non-200 status codes
print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
retries += 1
if retries < API_MAX_RETRIES:
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
time.sleep(API_RETRY_DELAY)
continue # Go to next iteration of while loop
else:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
except requests.exceptions.Timeout:
retries += 1
print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
if retries < API_MAX_RETRIES:
time.sleep(API_RETRY_DELAY)
continue
else:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
except requests.exceptions.RequestException as e:
retries += 1
print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
if retries < API_MAX_RETRIES:
time.sleep(API_RETRY_DELAY)
continue
else:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
except Exception as e:
retries += 1
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
print(traceback.format_exc())
if retries < API_MAX_RETRIES:
time.sleep(API_RETRY_DELAY)
continue
else:
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
# If loop finishes without returning, max retries were hit
print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts.")
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
# --- Checkpoint Saving Function ---
def save_checkpoint(data_to_save, output_dir, dataset_features):
"""Saves the current state of the data to disk."""
if not data_to_save:
print("Checkpoint: No data available to save.")
return
# Ensure output directory exists before saving
os.makedirs(output_dir, exist_ok=True)
print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
try:
# Convert list of dicts back to Dataset object
checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
checkpoint_dataset.save_to_disk(output_dir)
print(f"Checkpoint: Saved successfully to {output_dir}")
except Exception as ckpt_save_e:
print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
# Fallback to JSON Lines (optional, but good practice)
output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl") # Save inside the dir
print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
try:
with open(output_jsonl_path, 'w', encoding='utf-8') as f:
for item in data_to_save:
# Basic serialization handling for common types like numpy arrays
serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
except Exception as json_save_e:
print(f"Error saving checkpoint as JSON lines: {json_save_e}")
# --- Main Processing Logic ---
print("Checking for existing checkpoint/output dataset...")
dataset = None
original_features = None # Initialize
try:
    # Check whether the output directory exists and looks like a Hugging Face
    # datasets directory (dataset_info.json or state.json are the usual markers).
potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
if os.path.exists(OUTPUT_DATASET_DIR) and \
(os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
try:
dataset = load_from_disk(OUTPUT_DATASET_DIR)
            original_features = dataset.features  # capture the saved dataset's features
print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
except Exception as load_ckpt_e:
print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
print("Falling back to loading original input dataset.")
dataset = None # Ensure we proceed to load original if checkpoint load failed
else:
print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
# If no checkpoint, ensure dataset is None so original loading happens
    # If dataset is still None (no checkpoint was found, or loading it failed)
if dataset is None:
print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
dataset = load_from_disk(INPUT_DATASET_DIR)
original_features = dataset.features
print(f"Original dataset loaded successfully with {len(dataset)} rows.")
except Exception as initial_load_e:
print(f"FATAL: Error during initial dataset loading (original or checkpoint): {initial_load_e}")
    print(traceback.format_exc())  # print the full traceback
exit(1)
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
# --- Pre-calculation Step for GPT-4o ---
print("Pre-calculating GPT-4o tasks...")
tasks_to_process = []
# Use a list of dictionaries, which is mutable and easier for direct updates
updated_data = list(dataset)
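# Per-row schema assumed below (inferred from the keys this loop reads and writes):
#   model_{1..3}               - model name assigned to each response slot
#   prompt_text_{1..3}         - prompt text for that slot
#   response_text_{1..3}       - response text (filled in by this script)
#   response_audio_path_{1..3} - generated audio path (filled in by this script)
#   question_audio / question_text / metadata / source_dataset - shared inputs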
for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset for GPT-4o tasks")):
for i in range(1, 4):
model_key = f"model_{i}"
response_text_key = f"response_text_{i}"
prompt_text_key = f"prompt_text_{i}"
response_audio_key = f"response_audio_path_{i}" # Key for storing the *new* audio path
model_assigned = row.get(model_key)
response_text_exists = row.get(response_text_key) is not None
# Check for the specific model name used in the dataset
if model_assigned == GPT4O_MODEL_NAME and not response_text_exists:
question_audio_path = row.get('question_audio')
if not question_audio_path or not os.path.exists(question_audio_path): # Check path validity here
print(f"Warning (Row {idx}, Slot {i}): Skipping GPT-4o task - Missing or invalid 'question_audio' path: {question_audio_path}")
                # Skip task creation here; to pre-fill an error instead, one could set:
                #   updated_data[idx][response_text_key] = "[ERROR: Missing input audio]"
                #   updated_data[idx][response_audio_key] = None
continue # Skip this task
metadata_str = row.get('metadata', "{}")
source_dataset = row.get('source_dataset')
metadata = {}
try:
if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
elif isinstance(metadata_str, dict): metadata = metadata_str
except json.JSONDecodeError: pass
history_messages = []
if source_dataset == 'ultra':
history_str = metadata.get('history', '')
if history_str: history_messages = parse_ultra_history(history_str)
unique_id = str(uuid.uuid4()).replace("-", "")
output_audio_filename = f"gpt4o_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
task_info = {
"row_idx": idx,
"slot_idx": i,
# No API key needed here as it's global/single
"history_messages": history_messages,
"prompt_text": row.get(prompt_text_key, ""),
"question_text": row.get('question_text', ""), # Pass question text
"question_audio_path": question_audio_path,
"output_audio_filepath": output_audio_filepath,
}
tasks_to_process.append(task_info)
# Decide if you process all slots or just the first unfilled one
# break # Uncomment this line if you only want the *first* unfilled gpt4o slot per row processed
total_tasks = len(tasks_to_process)
if total_tasks == 0:
print("No GPT-4o tasks found needing processing.")
exit(0)
print(f"Found {total_tasks} GPT-4o tasks to process.")
# --- Threaded Execution with Checkpointing for GPT-4o ---
print(f"Starting GPT-4o processing with up to {MAX_WORKERS} worker threads...")
start_total_time = time.time()
tasks_completed = 0
tasks_failed = 0
completed_since_last_save = 0 # <-- Counter for checkpointing
# Use context manager for ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
future_to_task = {executor.submit(call_gpt4o_api_worker, task): task for task in tasks_to_process}
for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GPT-4o tasks"):
task_info = future_to_task[future] # Get original task info
row_idx = task_info["row_idx"]
slot_idx = task_info["slot_idx"]
result = None # Define result scope
try:
result = future.result()
# --- Direct Update and Checkpointing Logic ---
response_text_key = f"response_text_{slot_idx}"
response_audio_key = f"response_audio_path_{slot_idx}"
if 0 <= row_idx < len(updated_data):
updated_data[row_idx][response_text_key] = result["response_text"]
updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
if result["saved_audio_path"] is None or "[ERROR" in result["response_text"]: # Check for error marker
tasks_failed += 1
else:
print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
tasks_failed += 1 # Count as failed if index is bad
tasks_completed += 1
completed_since_last_save += 1 # Increment checkpoint counter
# Check if it's time to save a checkpoint
if completed_since_last_save >= CHECKPOINT_INTERVAL:
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
completed_since_last_save = 0 # Reset counter
except Exception as exc: # Catch exceptions raised *by* the future/worker if not handled inside
print(f"Error (Row {row_idx}, Slot {slot_idx}): GPT-4o Task generated an unhandled exception: {exc}")
print(traceback.format_exc())
# Attempt to record error in the main data structure
response_text_key = f"response_text_{slot_idx}"
response_audio_key = f"response_audio_path_{slot_idx}"
if 0 <= row_idx < len(updated_data):
updated_data[row_idx][response_text_key] = f"[ERROR: Worker Crash - {exc}]"
updated_data[row_idx][response_audio_key] = None
else:
print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
tasks_failed += 1
tasks_completed += 1 # Count as completed (though failed)
completed_since_last_save += 1 # Also increment for checkpointing
# Check if it's time to save a checkpoint even after an error
if completed_since_last_save >= CHECKPOINT_INTERVAL:
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
completed_since_last_save = 0 # Reset counter
end_total_time = time.time()
print("\n--- GPT-4o Processing Complete ---")
print(f"Total GPT-4o tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
print(f"Total GPT-4o processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
# --- Final Save ---
# Save one last time to ensure any remaining processed items (< CHECKPOINT_INTERVAL) are saved
print("\nPerforming final save...")
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
print("\nScript finished.")