|
|
import os |
|
|
import json |
|
|
import base64 |
|
|
import uuid |
|
|
import time |
|
|
import re |
|
|
from io import BytesIO |
|
|
import concurrent.futures |
|
|
from tqdm import tqdm |
|
|
import threading |
|
|
import itertools |
|
|
import traceback |
|
|
|
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
from zhipuai import ZhipuAI |
|
|
|
|
|
|
|
|
# The zhipuai SDK raises APIStatusError for non-2xx HTTP responses; the worker
# below catches it specifically to inspect status codes / error bodies. Some
# SDK versions do not expose it at this private path, so fall back to a
# minimal stand-in with the same attributes.
try:
    from zhipuai.core._errors import APIStatusError
except ImportError:
    print("Warning: zhipuai.core._errors.APIStatusError not found. Using generic Exception for status errors.")

    class APIStatusError(Exception):
        """Fallback stand-in mirroring the SDK exception's attributes
        (status_code, body, message) used by the worker's error handling."""

        def __init__(self, message, status_code=None, body=None):
            super().__init__(message)
            self.status_code = status_code
            self.body = body
            self.message = message
|
|
|
|
|
from datasets import load_from_disk, Dataset
from dotenv import load_dotenv

# Load environment variables from a local .env file (if present) before any
# configuration below is read.
load_dotenv()
|
|
|
|
|
|
|
|
# Model identifier for ZhipuAI's voice-capable GLM chat model.
GLM_MODEL_NAME = "glm-4-voice"

# SECURITY: API keys must not be committed to source control. Prefer supplying
# them via the ZHIPUAI_API_KEYS environment variable (comma-separated); the
# hardcoded list below is retained only as a backward-compatible fallback and
# these keys should be rotated and removed.
_env_keys = os.environ.get("ZHIPUAI_API_KEYS", "")
if _env_keys.strip():
    ZHIPUAI_API_KEYS = [k.strip() for k in _env_keys.split(",") if k.strip()]
else:
    ZHIPUAI_API_KEYS = [
        "14a67189b8bc4ee489e83b6247c36d0e.AIPUNrII50wREvsh",
        "72120787822c4123a9654965ff90e4e6.JS1nuey9MncQscPa",
        "d41b3b5bb49f4c8680b3836e7fc49bbf.u0jGxYc5sYPeRr5p",
        "bc9bccd6ddd145fc844a014521c26868.JwsZXHzA3l32dDwz",
        "0e5a05d709794737923ebd122e07d491.sL67ALh6BiLYaaGW",
        "db87c1fda8af4eb8b505f36e791d700d.w5M0Q3ZssT55tvlW",
        "1594ac60fbca4973809f4da425238e0c.ZMMfchqbok992Dmu",
        "469c0fa3b14e4913b1d14bc5d6f0c858.0KdQjFqdi66VPMnb",
        "b9b538bb0e134438bacaf922b023d1fd.sogFUUp57UJ8YSd6",
        "50bb382993a345cfa35833fc89caaa52.oR921jSW8iwzCV22",
        "44512bbede5940f7964db7694bfc04df.yhDEQyPOXQCqh1Mn",
        "99aba409b55c432696b9d5f1ff565d30.GmfRNngBOo8qDUbf"
    ]

if not ZHIPUAI_API_KEYS:
    print("FATAL: No ZHIPUAI_API_KEYS provided in the list.")
    exit(1)

# Deduplicate while preserving order (dict.fromkeys keeps first occurrence).
unique_keys = list(dict.fromkeys(ZHIPUAI_API_KEYS))
if len(unique_keys) != len(ZHIPUAI_API_KEYS):
    print(f"Warning: Duplicate API keys found and removed. Using {len(unique_keys)} unique keys.")
ZHIPUAI_API_KEYS = unique_keys

# Round-robin iterator over the key pool. All access is guarded by key_lock;
# keys that hit fatal billing errors (code 1113) get added to disabled_keys
# and are skipped thereafter.
key_cycler = itertools.cycle(ZHIPUAI_API_KEYS)
key_lock = threading.Lock()
disabled_keys = set()
|
|
|
|
|
class AllKeysDisabledError(Exception):
    """Raised when every configured API key has been disabled."""
|
|
|
|
|
def get_next_active_key():
    """
    Thread-safely return the next API key from the round-robin cycle,
    skipping disabled keys.

    Raises:
        AllKeysDisabledError: if every configured key has been disabled.
        RuntimeError: defensive guard if no active key was found even though
            not all keys are disabled (should be unreachable, since checking
            len(ZHIPUAI_API_KEYS) consecutive cycle entries visits every key).
    """
    with key_lock:
        total_keys = len(ZHIPUAI_API_KEYS)
        # One pass over `total_keys` consecutive cycle entries covers every
        # key exactly once, so a bounded for-loop suffices. (The original
        # while-loop carried a watchdog check that could never trigger.)
        for _ in range(total_keys):
            candidate = next(key_cycler)
            if candidate not in disabled_keys:
                return candidate

        if len(disabled_keys) == total_keys:
            raise AllKeysDisabledError("All API keys have been disabled.")

        # Defensive fallback — not expected to be reachable.
        print(f"Warning: Could not find an active key after checking {total_keys}. Disabled: {len(disabled_keys)}/{total_keys}")
        raise RuntimeError("Failed to find an active API key.")
|
|
|
|
|
|
|
|
|
|
|
# --- Dataset locations -------------------------------------------------------
# Input: pre-sampled preference tasks; output: the same rows with GLM
# responses merged in. NOTE(review): absolute paths are machine-specific —
# confirm before running on another host.
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_glm"

# --- Generated-audio output --------------------------------------------------
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/glm_voice"
OUTPUT_AUDIO_FORMAT = "wav"
# Fallback sample rate, used only when soundfile reports a non-positive rate.
OUTPUT_AUDIO_SAMPLERATE = 44100

# --- API retry / concurrency settings ---------------------------------------
API_RETRY_DELAY = 5   # seconds to sleep between retry attempts
API_MAX_RETRIES = 3   # attempts per task (a 1113 key switch does not increment this)
MAX_WORKERS = 10      # thread-pool size for concurrent API calls
|
|
|
|
|
|
|
|
def encode_audio_base64(audio_path):
    """Read the file at *audio_path* and return its bytes as a base64 string.

    Returns None (after printing a warning/error) when the path is empty,
    the file does not exist, or reading fails.
    """
    path_is_usable = bool(audio_path) and os.path.exists(audio_path)
    if not path_is_usable:
        print(f"Warning: Input audio file not found or path is empty: {audio_path}")
        return None
    try:
        with open(audio_path, "rb") as fh:
            raw_bytes = fh.read()
    except Exception as e:
        print(f"Error encoding audio file {audio_path}: {e}")
        return None
    return base64.b64encode(raw_bytes).decode("utf-8")
|
|
|
|
|
def parse_ultra_history(history_str):
    """Parse an '[USER] ... [ASSISTANT] ...' transcript string into a list of
    {'role': ..., 'content': ...} message dicts.

    Role tags become lowercase role names; content is stripped and empty
    turns are dropped. Returns an empty list when no tags are found.
    """
    turn_pattern = re.compile(
        r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)"
    )
    return [
        {"role": tag.lower(), "content": body.strip()}
        for tag, body in turn_pattern.findall(history_str)
        if body.strip()
    ]
|
|
|
|
|
|
|
|
def call_glm_voice_api_worker(task_info: dict) -> dict:
    """
    Worker function to call GLM Voice API, handling key disabling for error 1113,
    and flattening history into the user prompt with clear markers.
    (Incorporates Method 2 flattening into user's worker structure)

    task_info keys consumed: row_idx, slot_idx, api_key, history_messages,
    prompt_text, question_audio_path, output_audio_filepath.

    Always returns a dict with keys: row_idx, slot_idx, response_text,
    saved_audio_path (None on any failure path).
    """
    # Unpack the task description assembled during the scanning phase.
    row_idx = task_info["row_idx"]
    slot_idx = task_info["slot_idx"]
    current_api_key = task_info["api_key"]
    history_messages = task_info["history_messages"]
    prompt_text = task_info["prompt_text"]
    question_audio_path = task_info["question_audio_path"]
    output_audio_filepath = task_info["output_audio_filepath"]

    retries = 0
    local_glm_client = None  # lazily (re)built whenever the key changes

    while retries < API_MAX_RETRIES:
        # (Re)initialize the client when there is none yet or the key changed
        # since the last attempt (e.g. after a 1113 key switch).
        if local_glm_client is None or getattr(local_glm_client, 'api_key', None) != current_api_key:
            try:
                with key_lock:
                    if current_api_key in disabled_keys:
                        print(f"Info (Row {row_idx}, Slot {slot_idx}): Assigned key ...{current_api_key[-6:]} was disabled before use, getting new key.")
                        # NOTE(review): get_next_active_key() re-acquires
                        # key_lock, which is a non-reentrant threading.Lock —
                        # if this call happens while key_lock is held, it
                        # would deadlock. Confirm the intended lock scope.
                        current_api_key = get_next_active_key()
                        task_info["api_key"] = current_api_key
                print(f" [Thread-{threading.get_ident()}] Initializing client for Row {row_idx}, Slot {slot_idx} (Key: ...{current_api_key[-6:]})")
                local_glm_client = ZhipuAI(api_key=current_api_key)
            except AllKeysDisabledError:
                print(f"FATAL (Row {row_idx}, Slot {slot_idx}): All API keys are disabled. Cannot proceed with task.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: All Keys Disabled]", "saved_audio_path": None}
            except Exception as client_init_e:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Failed to initialize ZhipuAI client with key ...{current_api_key[-6:]}: {client_init_e}")
                retries += 1
                time.sleep(API_RETRY_DELAY)
                continue

        try:
            # Encode the input audio fresh on every attempt; bail out of the
            # whole task if the file is missing or unreadable.
            base64_audio_data = encode_audio_base64(question_audio_path)
            if not base64_audio_data:
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GLM API call - missing input audio.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
            # Derive the container format from the extension; default to wav.
            input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'

            # Flatten multi-turn history into a single marked-up text prompt
            # ("Method 2"), since the request carries one user message only.
            text_parts = []
            if history_messages:
                print(f" (Row {row_idx}, Slot {slot_idx}) Flattening history ({len(history_messages)} turns) into prompt.")
                text_parts.append("--- Start of Conversation History ---")
                for msg in history_messages:
                    role_tag = "[User]" if msg['role'] == 'user' else "[Assistant]"
                    content_str = str(msg.get('content', '')).strip()
                    if content_str:
                        text_parts.append(f"{role_tag}: {content_str}")
                text_parts.append("--- End of Conversation History ---")
                text_parts.append("\n--- Current Task ---")
                text_parts.append("Based on the conversation history above and the accompanying audio input, please respond to the following request:")
            else:
                print(f" (Row {row_idx}, Slot {slot_idx}) No history found. Using prompt directly.")
                text_parts.append("--- Current Task ---")
                text_parts.append("Please respond to the following request based on the accompanying audio input:")

            if prompt_text:
                text_parts.append(prompt_text.strip())

            combined_user_text = "\n".join(text_parts)

            # Single user message: combined text plus the base64 input audio.
            user_content_list = [
                {"type": "text", "text": combined_user_text},
                {"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
            ]

            messages = [{"role": "user", "content": user_content_list}]

            # Blocking (non-streamed) chat completion call.
            response = local_glm_client.chat.completions.create(
                model=GLM_MODEL_NAME,
                messages=messages,
                stream=False
            )

            if response and response.choices:
                message = response.choices[0].message
                # NOTE(review): message.content may be None for audio-only
                # replies; the .strip() calls below would then raise
                # AttributeError (caught by the generic handler). Confirm.
                collected_text = message.content
                # NOTE(review): assumes the SDK exposes audio as a dict-like
                # with a 'data' key — verify against the installed zhipuai
                # version.
                audio_info = getattr(message, 'audio', None)
                if audio_info and 'data' in audio_info:
                    audio_base64_string = audio_info['data']
                    try:
                        decoded_data = base64.b64decode(audio_base64_string)
                        if len(decoded_data) == 0:
                            print(f"Warning (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM returned empty audio data.")
                            return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}

                        os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)

                        # Decode the returned audio bytes; try int16 first,
                        # then fall back to float32 scaled into int16 range.
                        with BytesIO(decoded_data) as bio:
                            try:
                                audio_data, samplerate = sf.read(bio, dtype='int16')
                            except Exception:
                                bio.seek(0)
                                try:
                                    audio_data_float, samplerate = sf.read(bio, dtype='float32')
                                    audio_data = (audio_data_float * 32767).astype(np.int16)
                                except Exception as sf_read_err_float:
                                    print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Soundfile failed to read audio data: {sf_read_err_float}")
                                    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}

                        # Preserve the decoded sample rate when valid; fall
                        # back to the configured default otherwise.
                        write_samplerate = samplerate if samplerate > 0 else OUTPUT_AUDIO_SAMPLERATE
                        sf.write(output_audio_filepath, audio_data, write_samplerate)

                        # Success: both text and saved audio path returned.
                        return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": output_audio_filepath}

                    except base64.binascii.Error as b64_e:
                        print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM b64 decode failed: {b64_e}")
                        return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
                    except Exception as e:
                        print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Saving GLM audio failed: {e}")
                        return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
                else:
                    # Text-only response: keep the text, no audio path.
                    print(f"Warning (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): No audio data in GLM response.")
                    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
            else:
                # Empty/malformed response: burn a retry and try again.
                print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Invalid/empty GLM API response. Response: {response}")
                retries += 1
                time.sleep(API_RETRY_DELAY)
                continue

        except APIStatusError as e:
            # HTTP-status errors from the SDK: log details, then decide
            # between disabling the key (billing code 1113) and retrying.
            print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): APIStatusError Encountered")
            print(f" Status Code: {getattr(e, 'status_code', 'N/A')}")
            error_details = getattr(e, 'body', getattr(e, 'message', str(e)))
            print(f" Error Details: {error_details}")

            is_overdue_error = False
            status_code = getattr(e, 'status_code', None)

            # Only bother parsing the body for 429s, or 400s whose text
            # mentions code 1113 ("account overdue").
            if status_code == 429 or (status_code == 400 and '1113' in str(error_details)):
                try:
                    error_body = {}
                    # Body may arrive as a JSON string or already as a dict.
                    if isinstance(error_details, (str, bytes)) and error_details.strip().startswith('{'):
                        error_body = json.loads(error_details)
                    elif isinstance(error_details, dict):
                        error_body = error_details

                    if isinstance(error_body, dict) and str(error_body.get("error", {}).get("code", "")) == "1113":
                        is_overdue_error = True
                except (json.JSONDecodeError, AttributeError):
                    # Unparseable body: treat as a non-1113 error below.
                    pass
                except Exception as parse_err:
                    print(f"Warning: Error parsing API error body: {parse_err}")

            if is_overdue_error:
                # Billing failure is permanent for this key: disable it and
                # switch keys WITHOUT consuming a retry.
                key_to_disable = current_api_key
                print(f"Error (Row {row_idx}, Slot {slot_idx}): Account overdue (1113) for Key ...{key_to_disable[-6:]}. Disabling key.")
                with key_lock:
                    disabled_keys.add(key_to_disable)
                    print(f" Disabled keys count: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")

                try:
                    current_api_key = get_next_active_key()
                    print(f" (Row {row_idx}, Slot {slot_idx}) Switched to new key ...{current_api_key[-6:]} for next attempt.")
                    local_glm_client = None  # force client rebuild with the new key
                    continue
                except AllKeysDisabledError:
                    print(f"FATAL (Row {row_idx}, Slot {slot_idx}): All API keys are disabled after key ...{key_to_disable[-6:]} failed. Cannot retry task.")
                    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: All Keys Disabled]", "saved_audio_path": None}

            else:
                # Transient/other status error: normal retry path.
                retries += 1
                print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM API Call Attempt {retries}/{API_MAX_RETRIES} failed: HTTP {status_code}, {error_details}")
                if retries < API_MAX_RETRIES:
                    time.sleep(API_RETRY_DELAY)
                    continue
                else:
                    print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Max retries reached after API error.")
                    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries after status error]", "saved_audio_path": None}

        except Exception as e:
            # Catch-all for anything else (network, SDK, AttributeError on a
            # None content, ...); retry with backoff up to the limit.
            retries += 1
            print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Unexpected Error Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
            print(traceback.format_exc())
            if retries < API_MAX_RETRIES:
                time.sleep(API_RETRY_DELAY)
                continue
            else:
                print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Max retries reached after unexpected error.")
                return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries after unexpected error]", "saved_audio_path": None}

    # Loop exhausted without returning: report a generic max-retries failure.
    print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts (may include key switches).")
    return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Load the input dataset and prepare the audio output directory ----------
print("Loading dataset...")
try:
    dataset = load_from_disk(INPUT_DATASET_DIR)
    print(f"Dataset loaded successfully with {len(dataset)} rows from {INPUT_DATASET_DIR}.")
except Exception as e:
    print(f"FATAL: Error loading dataset from {INPUT_DATASET_DIR}: {e}")
    exit(1)

os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
# --- Scan the dataset and build the list of GLM tasks to run -----------------
# Each row has up to three model slots (model_1..model_3); a slot needs work
# when it is assigned to "glm_voice" and its response_text_<i> is still empty.
print("Pre-calculating GLM tasks and assigning initial API keys...")
tasks_to_process = []
original_data = list(dataset)  # materialize rows as plain dicts for later in-place merge
initial_keys_available = True

for idx, row in enumerate(tqdm(original_data, desc="Scanning dataset for GLM tasks")):
    if not initial_keys_available:
        print("Stopping task scanning as no active keys are available.")
        break

    for i in range(1, 4):
        model_key = f"model_{i}"
        response_text_key = f"response_text_{i}"
        prompt_text_key = f"prompt_text_{i}"
        model_assigned = row.get(model_key)

        # A non-empty, non-None response text means this slot is already done.
        response_text_exists = row.get(response_text_key) is not None and str(row.get(response_text_key)).strip() != ""

        if model_assigned == "glm_voice" and not response_text_exists:
            question_audio_path = row.get('question_audio')

            if not question_audio_path or not os.path.exists(question_audio_path):
                print(f"Warning (Row {idx}, Slot {i}): Skipping GLM task - Missing or invalid 'question_audio' path: {question_audio_path}")
                continue

            # Pre-assign a round-robin key per task; the worker may switch
            # keys later if this one gets disabled.
            try:
                assigned_key = get_next_active_key()
            except AllKeysDisabledError:
                print("FATAL: All API keys are disabled during initial task scanning. Cannot proceed.")
                initial_keys_available = False
                break
            except Exception as key_err:
                print(f"FATAL: Error getting initial API key: {key_err}. Stopping.")
                initial_keys_available = False
                break

            # Metadata may be a JSON string or already a dict; parse leniently.
            metadata_str = row.get('metadata', "{}")
            source_dataset = row.get('source_dataset')
            metadata = {}
            try:
                if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
                elif isinstance(metadata_str, dict): metadata = metadata_str
            except json.JSONDecodeError:
                print(f"Warning (Row {idx}): Could not parse metadata string: {metadata_str[:100]}...")
                pass

            # Only the 'ultra' source carries a tagged conversation history.
            history_messages = []
            if source_dataset == 'ultra':
                history_str = metadata.get('history', '')
                if history_str: history_messages = parse_ultra_history(history_str)

            # Unique output filename: row index, slot index, random UUID.
            unique_id = str(uuid.uuid4()).replace("-", "")
            output_audio_filename = f"glm_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
            output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)

            task_info = {
                "row_idx": idx,
                "slot_idx": i,
                "api_key": assigned_key,
                "history_messages": history_messages,
                "prompt_text": row.get(prompt_text_key, ""),
                "question_audio_path": question_audio_path,
                "output_audio_filepath": output_audio_filepath,
            }
            tasks_to_process.append(task_info)

            # NOTE(review): this break stops after the FIRST matching slot in
            # the row, so at most one glm_voice task is queued per row even if
            # several slots are assigned to glm_voice — confirm intentional.
            break

    if not initial_keys_available: break

# --- Summarize the scan and bail out early when there is nothing to do -------
total_tasks = len(tasks_to_process)
if total_tasks == 0:
    if not initial_keys_available:
        print("No tasks processed because all initial keys were disabled.")
    else:
        print("No GLM Voice tasks found needing processing.")
    exit(0)

print(f"Found {total_tasks} GLM Voice tasks to process using initially {len(ZHIPUAI_API_KEYS)} API keys.")
if len(disabled_keys) > 0:
    print(f"Note: {len(disabled_keys)} keys already marked as disabled (should not happen at this stage).")
|
|
|
|
|
|
|
|
|
|
|
# --- Run all tasks concurrently and collect per-(row, slot) results ----------
print(f"Starting GLM processing with up to {MAX_WORKERS} worker threads...")
start_total_time = time.time()
results = {}
tasks_completed = 0
tasks_failed = 0
executor_shutdown = False

with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # All tasks are submitted upfront; results are consumed as they finish.
    future_to_task = {executor.submit(call_glm_voice_api_worker, task): task for task in tasks_to_process}

    for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GLM tasks"):
        task = future_to_task[future]
        row_idx = task["row_idx"]
        slot_idx = task["slot_idx"]
        try:
            result = future.result()
            results[(row_idx, slot_idx)] = result

            # NOTE(review): every future was already submitted above, so this
            # flag only suppresses repeated CRITICAL messages — it does not
            # cancel pending futures or shut the executor down.
            if result["response_text"] == "[ERROR: All Keys Disabled]" and not executor_shutdown:
                print("\n--- CRITICAL: All Keys Disabled detected during execution. Stopping submission of new tasks. ---")
                executor_shutdown = True
                tasks_failed += 1
            # A missing audio path or an [ERROR ...] marker counts as failure.
            elif result["saved_audio_path"] is None or "[ERROR" in result["response_text"]:
                tasks_failed += 1
            tasks_completed += 1

        except Exception as exc:
            # Worker crashed outright: record a synthetic error result so the
            # merge step still writes something for this (row, slot).
            print(f"Error (Row {row_idx}, Slot {slot_idx}): GLM Task generated an unhandled exception: {exc}")
            print(traceback.format_exc())
            results[(row_idx, slot_idx)] = {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[ERROR: Worker Crash - {type(exc).__name__}]", "saved_audio_path": None}
            tasks_failed += 1
            tasks_completed += 1

# --- Timing and outcome summary ----------------------------------------------
end_total_time = time.time()
print("\n--- GLM Processing Complete ---")
print(f"Total GLM tasks attempted: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
print(f"Final disabled key count: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
print(f"Total GLM processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
|
|
|
|
|
|
|
|
|
|
|
# --- Merge worker results back into the row dicts ----------------------------
print("Merging GLM results...")
# NOTE: this aliases (does not copy) original_data; the row dicts are
# mutated in place.
updated_data = original_data
for (row_idx, slot_idx), result in tqdm(results.items(), desc="Merging GLM results"):
    response_text_key = f"response_text_{slot_idx}"
    response_audio_key = f"response_audio_path_{slot_idx}"

    if 0 <= row_idx < len(updated_data):
        if isinstance(updated_data[row_idx], dict):
            updated_data[row_idx][response_text_key] = result["response_text"]
            updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
        else:
            print(f"Warning: Item at index {row_idx} is not a dictionary. Skipping merge for Slot {slot_idx}.")
    else:
        print(f"Warning: Invalid row index {row_idx} encountered during GLM result merge.")
|
|
|
|
|
|
|
|
# --- Persist the merged dataset, with a JSONL fallback on failure ------------
if updated_data:
    print(f"\nSaving updated dataset with GLM results to {OUTPUT_DATASET_DIR}...")
    try:
        # Reuse the input schema so column types stay consistent.
        updated_dataset = Dataset.from_list(updated_data, features=dataset.features if dataset else None)
        updated_dataset.save_to_disk(OUTPUT_DATASET_DIR)
        print("Updated dataset saved successfully.")
    except Exception as final_save_e:
        print(f"Error saving final dataset using datasets lib: {final_save_e}")
        print(f"Final disabled key count at save: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
        print("Attempting to save as JSON lines as fallback...")

        output_jsonl_path = OUTPUT_DATASET_DIR.rstrip('/') + ".jsonl"
        try:
            with open(output_jsonl_path, 'w', encoding='utf-8') as f:
                for item in updated_data:
                    # Coerce each value into something json.dumps can handle:
                    # JSON-native types pass through, numpy arrays become
                    # lists, everything else is stringified.
                    serializable_item = {}
                    for k, v in item.items():
                        if isinstance(v, (str, int, float, bool, list, dict)) or v is None:
                            serializable_item[k] = v
                        elif isinstance(v, np.ndarray):
                            serializable_item[k] = v.tolist()
                        else:
                            serializable_item[k] = str(v)
                    f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
            print(f"Fallback save successful to {output_jsonl_path}")
        except Exception as json_save_e:
            print(f"Error saving as JSON lines: {json_save_e}")

else:
    print("No data was available to save (potentially all keys disabled early or no tasks processed).")