import os
import json
import base64
import uuid
import time
import re
import random
import concurrent.futures
import threading
import traceback

import requests
import numpy as np
from datasets import load_from_disk, Dataset
from dotenv import load_dotenv
from tqdm import tqdm

load_dotenv()

# --- Target model / API configuration ---
GPT4O_MODEL_NAME = "freeze_omni"  # slot marker in the dataset for tasks this script must fill
API_MODEL_NAME = "gpt-4o-mini-audio-preview"
API_ENDPOINT = "https://api2.aigcbest.top/v1/chat/completions"

try:
    # Read the key from the environment (populated by load_dotenv above);
    # never hard-code API tokens in source.
    API_TOKEN = os.getenv("AIGCBEST_API_KEY")
    if not API_TOKEN:
        raise ValueError("AIGCBEST_API_KEY environment variable not set.")
    print("AIGCBEST API Key loaded.")
except Exception as e:
    print(f"FATAL: Error getting API Key: {e}")
    exit(1)

# --- Dataset paths ---
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o_mini"

# --- Generated-audio output settings ---
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_mini"
OUTPUT_AUDIO_FORMAT = "wav"
AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse']

# --- Request / concurrency settings ---
API_TIMEOUT = 240      # seconds per request
API_RETRY_DELAY = 5    # seconds between retry attempts
API_MAX_RETRIES = 3
MAX_WORKERS = 8

# Save intermediate results to disk every N completed tasks.
CHECKPOINT_INTERVAL = 500


def encode_audio_base64(audio_path):
    """Read an audio file and return its contents base64-encoded, or None on failure."""
    if not audio_path or not os.path.exists(audio_path):
        print(f"Warning: Input audio file not found or path is empty: {audio_path}")
        return None
    try:
        with open(audio_path, "rb") as audio_file:
            return base64.b64encode(audio_file.read()).decode("utf-8")
    except Exception as e:
        print(f"Error encoding audio file {audio_path}: {e}")
        return None


def parse_ultra_history(history_str):
    """Parse an Ultra-style dialogue history string into a list of chat messages.

    The history alternates "[USER]" / "[ASSISTANT]" tags; each tagged segment
    becomes one {"role": ..., "content": ...} message.
    """
    messages = []
    pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
    matches = pattern.findall(history_str)
    if not matches:
        return []
    for role_tag, content in matches:
        role = role_tag.lower()
        cleaned_content = content.strip()
        if cleaned_content:
            messages.append({"role": role, "content": cleaned_content})
    return messages
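
# Illustrative example (hypothetical history string, not from the dataset):
#   parse_ultra_history("[USER] What's 2+2? [ASSISTANT] 4.")
#   -> [{"role": "user", "content": "What's 2+2?"},
#       {"role": "assistant", "content": "4."}]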


def call_gpt4o_api_worker(task_info):
    """Call the GPT-4o-mini audio API for a single task and save any returned audio.

    Returns a dict with the row/slot indices, the response text (or an error
    marker), and the path of the saved audio file (or None).
    """
    row_idx = task_info["row_idx"]
    slot_idx = task_info["slot_idx"]
    history_messages = task_info["history_messages"]
    prompt_text = task_info["prompt_text"]
    question_text = task_info["question_text"]
    question_audio_path = task_info["question_audio_path"]
    output_audio_filepath = task_info["output_audio_filepath"]

    retries = 0
    headers = {
        'Accept': 'application/json',
        'Authorization': f'Bearer {API_TOKEN}',
        'Content-Type': 'application/json'
    }
    # Pick a random voice per task so generated responses vary in speaker.
    selected_voice = random.choice(AVAILABLE_VOICES)

    while retries < API_MAX_RETRIES:
        try:
            base64_audio_data = encode_audio_base64(question_audio_path)
            if not base64_audio_data:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}

            # Infer the input audio format from the file extension, defaulting to wav.
            input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'

            # User turn: the prompt text plus the question audio as base64.
            combined_text = f"{prompt_text}"
            user_content_list = [
                {"type": "text", "text": combined_text},
                {"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
            ]
            messages = history_messages + [{"role": "user", "content": user_content_list}]

            payload = {
                "model": API_MODEL_NAME,
                "modalities": ["text", "audio"],
                "audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
                "messages": messages
            }

            response = requests.post(
                API_ENDPOINT,
                headers=headers,
                json=payload,
                timeout=API_TIMEOUT
            )
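
            # Expected response shape (an assumption inferred from the parsing
            # below and the OpenAI audio-preview chat-completions format, not
            # verified against this proxy endpoint):
            #   {"choices": [{"message": {
            #       "audio": {"data": "<base64 wav>", "transcript": "..."},
            #       "content": "..." or [{"type": "text", "text": "..."}]}}]}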

            if response.status_code == 200:
                try:
                    response_data = response.json()

                    choices = response_data.get('choices')
                    if not choices or not isinstance(choices, list) or len(choices) == 0:
                        raise ValueError("Invalid or empty 'choices' field in response.")

                    message_content = choices[0].get('message', {})
                    if not message_content:
                        raise ValueError("Missing 'message' field in the first choice.")

                    audio_info = message_content.get('audio', {})
                    if not isinstance(audio_info, dict):
                        audio_info = {}

                    audio_base64_string = audio_info.get('data', '')

                    # Prefer the audio transcript; fall back to the message
                    # content, which may be a list of parts or a plain string.
                    collected_text = audio_info.get('transcript', '').strip()
                    if not collected_text:
                        text_content_list = message_content.get('content', [])
                        if isinstance(text_content_list, list):
                            for item in text_content_list:
                                if isinstance(item, dict) and item.get("type") == "text":
                                    collected_text = item.get("text", "").strip()
                                    break
                        elif isinstance(message_content.get('content'), str):
                            collected_text = message_content['content'].strip()

                    if not collected_text:
                        print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
                    if not audio_base64_string:
                        print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")

                    saved_audio_path = None
                    if audio_base64_string:
                        try:
                            wav_bytes = base64.b64decode(audio_base64_string)
                            if len(wav_bytes) == 0:
                                print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
                            else:
                                os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
                                with open(output_audio_filepath, "wb") as f:
                                    f.write(wav_bytes)
                                saved_audio_path = output_audio_filepath
                        except base64.binascii.Error as b64_err:
                            print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
                        except Exception as e:
                            print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")

                    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}

                except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
                    print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
                    print(f"  Response Text (start): {response.text[:500]}...")
                    retries += 1
                    print(f"  Retrying task ({retries}/{API_MAX_RETRIES})...")
                    time.sleep(API_RETRY_DELAY)
                    continue
                except Exception as e:
                    print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
                    print(traceback.format_exc())
                    retries += 1
                    print(f"  Retrying task ({retries}/{API_MAX_RETRIES})...")
                    time.sleep(API_RETRY_DELAY)
                    continue

            else:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
                retries += 1
                if retries < API_MAX_RETRIES:
                    print(f"  Retrying task ({retries}/{API_MAX_RETRIES})...")
                    time.sleep(API_RETRY_DELAY)
                    continue
                else:
                    print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
                    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}

        except requests.exceptions.Timeout:
            retries += 1
            print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
            if retries < API_MAX_RETRIES:
                time.sleep(API_RETRY_DELAY)
                continue
            else:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}

        except requests.exceptions.RequestException as e:
            retries += 1
            print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
            if retries < API_MAX_RETRIES:
                time.sleep(API_RETRY_DELAY)
                continue
            else:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}

        except Exception as e:
            retries += 1
            print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
            print(traceback.format_exc())
            if retries < API_MAX_RETRIES:
                time.sleep(API_RETRY_DELAY)
                continue
            else:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}

    # All retries exhausted without a successful return.
    print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts.")
    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}


def save_checkpoint(data_to_save, output_dir, dataset_features):
    """Save the current state of the data to disk, falling back to JSON lines."""
    if not data_to_save:
        print("Checkpoint: No data available to save.")
        return

    os.makedirs(output_dir, exist_ok=True)

    print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
    try:
        checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
        checkpoint_dataset.save_to_disk(output_dir)
        print(f"Checkpoint: Saved successfully to {output_dir}")
    except Exception as ckpt_save_e:
        print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")

        # Fallback: dump rows as JSON lines so progress is not lost.
        output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl")
        print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
        try:
            with open(output_jsonl_path, 'w', encoding='utf-8') as f:
                for item in data_to_save:
                    # numpy arrays are not JSON-serializable; convert them to lists.
                    serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
                    f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
            print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
        except Exception as json_save_e:
            print(f"Error saving checkpoint as JSON lines: {json_save_e}")


print("Checking for existing checkpoint/output dataset...")
dataset = None
original_features = None

try:
    # A directory written by Dataset.save_to_disk contains these marker files.
    potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
    potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")

    if os.path.exists(OUTPUT_DATASET_DIR) and \
       (os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
        print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
        try:
            dataset = load_from_disk(OUTPUT_DATASET_DIR)
            original_features = dataset.features
            print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
        except Exception as load_ckpt_e:
            print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
            print("Falling back to loading original input dataset.")
            dataset = None
    else:
        print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")

    if dataset is None:
        print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
        dataset = load_from_disk(INPUT_DATASET_DIR)
        original_features = dataset.features
        print(f"Original dataset loaded successfully with {len(dataset)} rows.")

except Exception as initial_load_e:
    print(f"FATAL: Error during initial dataset loading (original or checkpoint): {initial_load_e}")
    print(traceback.format_exc())
    exit(1)

os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)

print("Pre-calculating GPT-4o tasks...")
tasks_to_process = []

# Work on a mutable copy of the rows so results can be merged in place.
updated_data = list(dataset)

for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset for GPT-4o tasks")):
    # Each row has three response slots (model_1 .. model_3).
    for i in range(1, 4):
        model_key = f"model_{i}"
        response_text_key = f"response_text_{i}"
        prompt_text_key = f"prompt_text_{i}"
        response_audio_key = f"response_audio_path_{i}"

        model_assigned = row.get(model_key)
        response_text_exists = row.get(response_text_key) is not None

        # Only process slots assigned to this model that have no response yet.
        if model_assigned == GPT4O_MODEL_NAME and not response_text_exists:
            question_audio_path = row.get('question_audio')
            if not question_audio_path or not os.path.exists(question_audio_path):
                print(f"Warning (Row {idx}, Slot {i}): Skipping GPT-4o task - Missing or invalid 'question_audio' path: {question_audio_path}")
                continue

            metadata_str = row.get('metadata', "{}")
            source_dataset = row.get('source_dataset')
            metadata = {}
            try:
                if metadata_str and isinstance(metadata_str, str):
                    metadata = json.loads(metadata_str)
                elif isinstance(metadata_str, dict):
                    metadata = metadata_str
            except json.JSONDecodeError:
                pass

            # Ultra rows carry a tagged dialogue history in their metadata.
            history_messages = []
            if source_dataset == 'ultra':
                history_str = metadata.get('history', '')
                if history_str:
                    history_messages = parse_ultra_history(history_str)

            unique_id = str(uuid.uuid4()).replace("-", "")
            output_audio_filename = f"gpt4o_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
            output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)

            task_info = {
                "row_idx": idx,
                "slot_idx": i,
                "history_messages": history_messages,
                "prompt_text": row.get(prompt_text_key, ""),
                "question_text": row.get('question_text', ""),
                "question_audio_path": question_audio_path,
                "output_audio_filepath": output_audio_filepath,
            }
            tasks_to_process.append(task_info)

total_tasks = len(tasks_to_process)
if total_tasks == 0:
    print("No GPT-4o tasks found needing processing.")
    exit(0)

print(f"Found {total_tasks} GPT-4o tasks to process.")

print(f"Starting GPT-4o processing with up to {MAX_WORKERS} worker threads...")
start_total_time = time.time()

tasks_completed = 0
tasks_failed = 0
completed_since_last_save = 0

with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Map each future back to its task so results can be merged by row/slot.
    future_to_task = {executor.submit(call_gpt4o_api_worker, task): task for task in tasks_to_process}

    for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GPT-4o tasks"):
        task_info = future_to_task[future]
        row_idx = task_info["row_idx"]
        slot_idx = task_info["slot_idx"]
        result = None

        try:
            result = future.result()

            response_text_key = f"response_text_{slot_idx}"
            response_audio_key = f"response_audio_path_{slot_idx}"

            if 0 <= row_idx < len(updated_data):
                updated_data[row_idx][response_text_key] = result["response_text"]
                updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
                if result["saved_audio_path"] is None or "[ERROR" in result["response_text"]:
                    tasks_failed += 1
            else:
                print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
                tasks_failed += 1

            tasks_completed += 1
            completed_since_last_save += 1

            if completed_since_last_save >= CHECKPOINT_INTERVAL:
                save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
                completed_since_last_save = 0

        except Exception as exc:
            print(f"Error (Row {row_idx}, Slot {slot_idx}): GPT-4o Task generated an unhandled exception: {exc}")
            print(traceback.format_exc())

            # Record the crash in the row so the slot is not silently left empty.
            response_text_key = f"response_text_{slot_idx}"
            response_audio_key = f"response_audio_path_{slot_idx}"
            if 0 <= row_idx < len(updated_data):
                updated_data[row_idx][response_text_key] = f"[ERROR: Worker Crash - {exc}]"
                updated_data[row_idx][response_audio_key] = None
            else:
                print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")

            tasks_failed += 1
            tasks_completed += 1
            completed_since_last_save += 1

            if completed_since_last_save >= CHECKPOINT_INTERVAL:
                save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
                completed_since_last_save = 0

end_total_time = time.time()
print("\n--- GPT-4o Processing Complete ---")
print(f"Total GPT-4o tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
print(f"Total GPT-4o processing time: {(end_total_time - start_total_time)/60:.2f} minutes")

print("\nPerforming final save...")
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)

print("\nScript finished.")
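
# Typical invocation (assuming a .env file defining AIGCBEST_API_KEY in the
# working directory; the script name below is a placeholder):
#   python sample_gpt4o_mini.py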