# Uploaded with the upload-large-folder tool (commit ff53362, verified).
import os
import json
import time
import re # For parsing history
import uuid # For generating unique filenames
import torch # Kimi might return tensors
import soundfile as sf # For saving Kimi audio output
import sys
from datasets import load_from_disk, Dataset, Features, Audio, Value
from dotenv import load_dotenv
import datetime # For ETA formatting
from tqdm import tqdm # Import tqdm
import traceback # For detailed error printing
# --- Kimi-Audio Project Path Setup ---
# Makes the out-of-tree `kimia_infer` package importable by prepending its
# parent directory to sys.path before attempting the import.
# <--- *** IMPORTANT: Update this path to the PARENT directory containing the 'kimia_infer' folder *** --->
kimia_project_parent_dir = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio"
# Check if the path exists and add it to sys.path
if os.path.isdir(kimia_project_parent_dir):
    if kimia_project_parent_dir not in sys.path:
        sys.path.insert(0, kimia_project_parent_dir)
        print(f"Added '{kimia_project_parent_dir}' to Python path.")
    # Try importing KimiAudio only after potentially adding the path
    try:
        from kimia_infer.api.kimia import KimiAudio  # Kimi model class
    except ImportError as import_err:
        print(f"Error: Could not import KimiAudio from '{kimia_project_parent_dir}'.")
        print(f"ImportError: {import_err}")
        print("Please ensure the 'kimia_infer' directory exists within the specified path and check dependencies.")
        exit(1)
else:
    # Without the project checkout the model class cannot be loaded; abort early.
    print(f"Error: Kimi project parent directory not found: '{kimia_project_parent_dir}'")
    print("Please update the 'kimia_project_parent_dir' variable in the script.")
    exit(1)
# --- Configuration ---
load_dotenv()  # Load environment variables if needed (e.g., API keys, though not typical for local Kimi)

# 1. Model & Tokenizer Setup (Kimi Specific)
KIMI_MODEL_NAME = "kimi_audio"  # Identifier used in your dataset's model_N columns
KIMI_MODEL_PATH = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio/checkpoint/Kimi-Audio-7B-Instruct"  # Path to your Kimi model checkpoint directory
# KIMI_DEVICE = 'cuda'  # KimiAudio class likely handles device selection based on availability. Verify its internal logic if issues arise.
# KIMI_DTYPE = torch.bfloat16  # KimiAudio likely handles dtype internally.

# 2. Dataset Paths
INPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_sampling_tasks"  # Original source
OUTPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_tasks_with_kimi"  # Where Kimi processed data is saved/resumed from

# 3. Output Audio Configuration (Kimi Specific)
OUTPUT_AUDIO_ROOT_DIR = "/home/chenyifu/audio-r1/r1-a/generated_audio/kimi"  # Where Kimi generated audio files are saved
OUTPUT_AUDIO_FORMAT = "wav"
OUTPUT_AUDIO_SAMPLERATE = 24000  # Kimi example uses 24kHz output. Confirm this matches your model's expected/native output SR.

# 4. Kimi Call Settings (Based on example, adjust as needed)
# Forwarded verbatim as **kwargs to KimiAudio.generate in call_kimi_model.
KIMI_SAMPLING_PARAMS = {
    "audio_temperature": 0.8,
    "audio_top_k": 10,
    "text_temperature": 0.0,  # 0.0 for deterministic text, increase for more variety
    "text_top_k": 5,  # Relevant if text_temperature > 0
    "audio_repetition_penalty": 1.0,
    "audio_repetition_window_size": 64,
    "text_repetition_penalty": 1.0,
    "text_repetition_window_size": 16,
    # "max_new_tokens": 128  # Add if needed and supported by KimiAudio.generate
}
KIMI_OUTPUT_TYPE = "both"  # Generate both audio and text

# 5. Periodic Save Settings
SAVE_EVERY_N_SAMPLES = 50  # Save after processing this many samples
# --- Helper Functions ---
def format_time(seconds):
    """Render a duration in seconds as an H:MM:SS string ("N/A" for negative input)."""
    return "N/A" if seconds < 0 else str(datetime.timedelta(seconds=int(seconds)))
# REMOVED load_audio_minicpm - Kimi takes the path directly
def parse_ultra_history(history_str):
    """Parse the 'ultra' metadata conversation-history string into Kimi messages.

    Returns a list of {"role", "message_type", "content"} dicts; message_type is
    always "text". Unparseable input yields an empty list (with a warning).
    """
    # Relaxed pattern: tolerates stray whitespace in/around [USER]/[ASSISTANT] tags.
    pattern = re.compile(r"\[\s*(USER|ASSISTANT)\s*\]\s*([\s\S]*?)(?=\s*\[\s*(?:USER|ASSISTANT)\s*\]|$)")
    turns = pattern.findall(history_str)

    if not turns and history_str and history_str.strip():
        # Fallback for single-turn strings prefixed "user:"/"assistant:" (any case).
        recovered = []
        lowered = history_str.lower()
        if lowered.startswith("user:") or lowered.startswith("[user]"):
            body = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
            if body:
                recovered.append({"role": "user", "message_type": "text", "content": body})  # Add Kimi message_type
        elif lowered.startswith("assistant:") or lowered.startswith("[assistant]"):
            body = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
            if body:
                recovered.append({"role": "assistant", "message_type": "text", "content": body})  # Add Kimi message_type
        else:
            print(f"Warning: Could not parse history string format: {history_str[:100]}...")
        return recovered  # Whatever was recovered, possibly empty.

    # Normal path: one text message per tagged turn, skipping empty bodies.
    # IMPORTANT: message_type='text' is required by Kimi's history format.
    return [
        {"role": tag.strip().lower(), "message_type": "text", "content": body.strip()}
        for tag, body in turns
        if body.strip()
    ]
# --- Kimi Model Interaction Function ---
def call_kimi_model(model, messages_input, sampling_params, output_audio_filepath, output_sample_rate):
    """Call the Kimi-Audio model, save its audio output, and return the results.

    Args:
        model: Loaded KimiAudio instance exposing `generate(messages, **params, output_type=...)`.
        messages_input: List of {"role", "message_type", "content"} dicts (Kimi chat format).
        sampling_params: Keyword sampling arguments forwarded to `model.generate`.
        output_audio_filepath: Destination path for the generated WAV file.
        output_sample_rate: Sample rate (Hz) used when writing the WAV file.

    Returns:
        (text, audio_path): text is the model's text reply (or an "[ERROR: ...]"
        marker); audio_path is the saved WAV path, or None when no usable audio
        was produced. Never raises — all failures are reported via the markers.
    """
    try:
        # 1. Ensure Output Directory Exists
        os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
        # 2. Call Kimi's Generate Function.
        # NOTE(review): assumes generate returns (audio_tensor, text) when
        # output_type='both' — confirm against the KimiAudio API.
        wav_output, text_output = model.generate(
            messages_input,
            **sampling_params,
            output_type=KIMI_OUTPUT_TYPE  # Use 'both'
        )
        # 3. Process and Save Audio Output
        saved_audio_path = None
        if wav_output is not None and isinstance(wav_output, torch.Tensor) and wav_output.numel() > 0:  # Check if tensor is not empty
            try:
                # BUGFIX: numpy was referenced below (np.*) but never imported
                # anywhere in this file, raising NameError instead of the intended
                # ImportError. Importing here makes the ImportError handler reachable.
                import numpy as np
                # Flatten to 1-D on CPU and convert to numpy for soundfile.
                audio_data = wav_output.detach().cpu().view(-1).numpy()
                # Ensure data is float32 as supported by soundfile/WAV.
                if audio_data.dtype != 'float32':
                    if np.issubdtype(audio_data.dtype, np.integer):
                        # NOTE(review): integer samples are cast without rescaling to
                        # [-1, 1] (int16 would need / 32768.0) — confirm Kimi's range.
                        audio_data = audio_data.astype(np.float32)
                    else:
                        audio_data = audio_data.astype(np.float32)
                sf.write(output_audio_filepath, audio_data, output_sample_rate)
                # Sanity-check the written file; a bare WAV header alone is ~44 bytes,
                # so anything under 100 bytes is treated as a failed write.
                if os.path.exists(output_audio_filepath) and os.path.getsize(output_audio_filepath) > 100:
                    saved_audio_path = output_audio_filepath
                else:
                    print(f" Error: Kimi generate finished but output audio file seems empty or too small at {output_audio_filepath}")
                    if os.path.exists(output_audio_filepath):
                        try:
                            os.remove(output_audio_filepath)
                        except OSError as rm_err:
                            print(f" Warning: Could not remove empty/small file {output_audio_filepath}: {rm_err}")
            except ImportError:
                print("Error: NumPy library not found. Please install it (`pip install numpy`)")
                return "[ERROR: NumPy Missing]", None  # Indicate failure clearly
            except Exception as sf_err:
                print(f" Error saving Kimi audio output to {output_audio_filepath}: {sf_err}")
                traceback.print_exc()
                # Don't leave a possibly-corrupt partial file behind.
                if os.path.exists(output_audio_filepath):
                    try:
                        os.remove(output_audio_filepath)
                    except OSError as rm_err:
                        print(f" Warning: Could not remove potentially corrupt file {output_audio_filepath}: {rm_err}")
        elif wav_output is None:
            print(" Warning: Kimi model did not return an audio tensor (wav_output is None).")
        elif isinstance(wav_output, torch.Tensor) and wav_output.numel() == 0:
            print(" Warning: Kimi model returned an empty audio tensor.")
        else:
            print(f" Warning: Kimi model returned unexpected audio output type: {type(wav_output)}. Expected torch.Tensor.")
        # 4. Process Text Output (text may legitimately be absent).
        response_text_cleaned = ""
        if isinstance(text_output, str):
            response_text_cleaned = text_output.strip()
        elif text_output is not None:
            response_text_cleaned = str(text_output).strip()  # Convert just in case
        else:
            # If text is None but audio might exist, use a specific marker
            if saved_audio_path:
                response_text_cleaned = "[Audio Generated, No Text Output]"
            else:
                response_text_cleaned = "[ERROR: No Text Output]"
        # Return text (even if audio failed) and the path (or None)
        return response_text_cleaned, saved_audio_path
    except Exception as e:
        print(f"\n --- Error during Kimi model call ---")
        # Avoid printing potentially huge message list directly
        first_message = messages_input[0] if messages_input else "N/A"
        last_message_content = messages_input[-1]['content'] if messages_input else "N/A"
        if isinstance(last_message_content, str) and len(last_message_content) > 100:
            last_message_preview = last_message_content[:100] + "..."
        else:
            last_message_preview = last_message_content
        print(f" Input Messages Info: Count={len(messages_input)}, First={first_message}, Last Content Preview='{last_message_preview}'")
        print(f" Exception Type: {type(e).__name__}")
        print(f" Error Details: {e}")
        print(" Traceback:")
        traceback.print_exc()
        print(" --- End Error Details ---")
        # Attempt cleanup of potentially incomplete output file.
        # (Removed the dead `'output_audio_filepath' in locals()` guard — it is a
        # parameter and therefore always present in locals().)
        if os.path.exists(output_audio_filepath):
            try:
                os.remove(output_audio_filepath)
            except OSError as rm_err:
                print(f" Warning: Could not remove file {output_audio_filepath} after error: {rm_err}")
        # Return clear error markers
        return "[ERROR: Kimi Model Call Failed]", None
# --- Dataset Saving Function (Modified for Kimi context) ---
def save_checkpoint(data_list, features, output_dir, fallback_dir=None):
"""Saves the current state of the data list as a Hugging Face Dataset."""
if not data_list:
print("\nSkipping checkpoint save: data list is empty.")
return
print(f"\nSaving checkpoint with {len(data_list)} rows to {output_dir}...")
try:
# Ensure the list contains dictionaries
data_to_save = [dict(item) for item in data_list]
# --- Feature Check/Adaptation (Optional but recommended) ---
# Sometimes saving fails if data types changed unexpectedly (e.g., None -> str)
# It's safer to create the Dataset *without* features first, then cast
temp_dataset = Dataset.from_list(data_to_save)
# Now cast to the original features, allowing potential None/type mismatches
# This might raise warnings but is often more robust than direct from_list with features
updated_dataset = temp_dataset.cast(features)
# --- End Feature Check ---
# Ensure output directory exists before saving
os.makedirs(output_dir, exist_ok=True)
updated_dataset.save_to_disk(output_dir)
print("Checkpoint saved successfully.")
except Exception as e:
print(f"Error saving checkpoint dataset using save_to_disk to {output_dir}: {e}")
traceback.print_exc()
if fallback_dir:
# Use Kimi-specific name in fallback path
fallback_path = os.path.join(fallback_dir, f"updated_{KIMI_MODEL_NAME}_data_checkpoint_{int(time.time())}.jsonl")
print(f"Attempting to save data as JSON Lines fallback to: {fallback_path}")
try:
os.makedirs(fallback_dir, exist_ok=True)
with open(fallback_path, 'w', encoding='utf-8') as f:
# Reuse data_to_save which is already list of dicts
for item in data_to_save:
# Ensure all values are serializable
serializable_item = {}
for k, v in item.items():
if isinstance(v, (datetime.datetime, datetime.date)):
serializable_item[k] = v.isoformat()
elif isinstance(v, bytes):
serializable_item[k] = v.decode('utf-8', errors='ignore')
elif isinstance(v, torch.Tensor): # Handle potential tensors if not caught earlier
print(f" Warning: Found unexpected Tensor for key '{k}' in fallback save. Converting to list.")
serializable_item[k] = v.tolist()
elif not isinstance(v, (str, int, float, bool, list, dict, type(None))):
print(f" Warning: Converting non-standard type {type(v)} for key '{k}' to string for JSON fallback.")
serializable_item[k] = str(v)
else:
serializable_item[k] = v
try:
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
except TypeError as json_type_err:
print(f" Skipping row due to JSON serialization error: {json_type_err} in item part: {k}={v}")
print("Fallback JSON Lines checkpoint saved successfully.")
except Exception as json_e:
print(f"Error saving fallback JSON Lines checkpoint: {json_e}")
# =============================================
# --- Main Processing Logic ---
# =============================================
# --- STEP 1: Dataset Loading (Modified for Resumption) ---
# Prefer the previously processed output dataset (resume mode); fall back to
# the original input dataset when no checkpoint exists or it fails to load.
print("="*30)
print("STEP 1: Loading Dataset")
print("="*30)
dataset = None
original_features = None  # Initialize
# Check if the Kimi-specific output directory exists
if os.path.exists(OUTPUT_DATASET_DIR):
    print(f"Found existing Kimi processed dataset directory at: {OUTPUT_DATASET_DIR}")
    print("Attempting to load it to resume processing...")
    try:
        dataset = load_from_disk(OUTPUT_DATASET_DIR)
        original_features = dataset.features  # Get features from the loaded dataset
        print(f"Resumed Kimi dataset loaded successfully with {len(dataset)} rows.")
        print(f"Features from resumed dataset: {original_features}")
    except Exception as e:
        # A corrupt/partial checkpoint is not fatal — fall back to the source data.
        print(f"Warning: Error loading existing Kimi dataset from {OUTPUT_DATASET_DIR}: {e}")
        traceback.print_exc()
        print("Will attempt to load the original input dataset instead.")
        dataset = None  # Reset dataset variable
else:
    print(f"No existing Kimi processed dataset found at {OUTPUT_DATASET_DIR}.")
    print("Will attempt to load the original input dataset.")
# If dataset is still None, load from the original input directory
if dataset is None:
    print(f"\nLoading original input dataset from: {INPUT_DATASET_DIR}")
    if not os.path.exists(INPUT_DATASET_DIR):
        print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
        exit(1)
    try:
        dataset = load_from_disk(INPUT_DATASET_DIR)
        original_features = dataset.features  # Get features from the input dataset
        print(f"Original input dataset loaded successfully with {len(dataset)} rows.")
        print(f"Features from input dataset: {original_features}")
    except Exception as e:
        print(f"FATAL: Error loading original input dataset from {INPUT_DATASET_DIR}: {e}")
        traceback.print_exc()
        exit(1)
# --- Ensure dataset and features were loaded ---
if dataset is None or original_features is None:
    print("FATAL: Failed to load any dataset. Exiting.")
    exit(1)
# --- End Dataset Loading ---
# --- STEP 2: Pre-computation - Identify Kimi Tasks ---
# Scan each row's three response slots and collect (row_index, slot_index)
# pairs that are assigned to Kimi but still missing a text response. At most
# one slot per row is queued per run; 'pkusafe' rows are processed first.
print("\n" + "="*30)
print(f"STEP 2: Identifying '{KIMI_MODEL_NAME}' Tasks to Process")
print("="*30)
pkusafe_tasks_indices = []
other_tasks_indices = []
# Iterate through the loaded dataset structure
for idx, row in enumerate(dataset):
    source_dataset = row.get('source_dataset')
    processed_in_row = False  # Flag to ensure we only pick one Kimi slot per row
    for i in range(1, 4):  # Check slots 1, 2, 3
        model_key = f"model_{i}"
        response_text_key = f"response_text_{i}"
        # Check if the slot is assigned to Kimi and is NOT yet filled (text response missing)
        is_target_model_task = row.get(model_key) == KIMI_MODEL_NAME
        is_unfilled = not row.get(response_text_key)  # True if None or empty string
        if is_target_model_task and is_unfilled and not processed_in_row:
            task_info = (idx, i)  # Store tuple of (original_row_index, slot_index)
            if source_dataset == 'pkusafe':
                pkusafe_tasks_indices.append(task_info)
            else:
                other_tasks_indices.append(task_info)
            processed_in_row = True  # Mark row as having a task identified
# Combine lists, prioritizing pkusafe
tasks_to_process_indices = pkusafe_tasks_indices + other_tasks_indices
total_tasks_to_process = len(tasks_to_process_indices)
print(f"Found {len(pkusafe_tasks_indices)} 'pkusafe' tasks and {len(other_tasks_indices)} other tasks requiring '{KIMI_MODEL_NAME}' processing in the loaded dataset.")
print(f"Total tasks remaining to process: {total_tasks_to_process}")
if total_tasks_to_process == 0:
    # Nothing left to do (all slots already filled in a previous run).
    print(f"\nNo remaining tasks to process for {KIMI_MODEL_NAME} based on the loaded dataset.")
    # Optionally, perform a final save for consistency
    # print("Performing a final save to ensure consistency...")
    # final_data_list = [dict(row) for row in dataset]
    # fallback_save_dir_final = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
    # save_checkpoint(final_data_list, original_features, OUTPUT_DATASET_DIR, fallback_save_dir_final)
    print("Exiting.")
    exit(0)
# --- End Pre-computation Step ---
# --- STEP 3: Load Kimi Model ---
print("\n" + "="*30)
print(f"STEP 3: Loading {KIMI_MODEL_NAME} Model")
print("="*30)
try:
    # Load Kimi model using the class imported earlier
    model = KimiAudio(model_path=KIMI_MODEL_PATH, load_detokenizer=True)  # Assuming detokenizer is needed based on example
    print(f"{KIMI_MODEL_NAME} model loaded successfully from {KIMI_MODEL_PATH}.")
except NameError:
    # KimiAudio is only bound if the guarded import at the top of the file succeeded.
    print("FATAL: KimiAudio class not defined. Import likely failed earlier.")
    exit(1)
except Exception as e:
    print(f"Error loading {KIMI_MODEL_NAME} model from {KIMI_MODEL_PATH}: {e}")
    traceback.print_exc()
    exit(1)
# --- STEP 4: Prepare for Processing ---
print("\n" + "="*30)
print(f"STEP 4: Preparing for {KIMI_MODEL_NAME} Processing")
print("="*30)
# Create output directories if they don't exist
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DATASET_DIR, exist_ok=True)
# Define and create fallback directory for Kimi (JSONL dumps when save_to_disk fails)
fallback_save_dir = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
os.makedirs(fallback_save_dir, exist_ok=True)
print(f"Audio outputs will be saved in: {OUTPUT_AUDIO_ROOT_DIR}")
print(f"Dataset checkpoints will be saved in: {OUTPUT_DATASET_DIR}")
print(f"Fallback checkpoints (JSONL) in: {fallback_save_dir}")
# Create a mutable list of dictionaries from the loaded dataset for updates
# (HF Dataset rows are read-only; results are written into this list instead).
updated_data = [dict(row) for row in dataset]  # Convert each row to a dictionary
tasks_processed_count = 0  # Count successful completions for average time calculation
start_total_time = time.time()
# --- STEP 5: Start Processing Loop ---
# For each pending (row, slot) task: validate inputs, build the Kimi message
# list (history + text prompt + audio query), call the model, record the text
# and audio-path results into updated_data, and checkpoint periodically.
print("\n" + "="*30)
print(f"STEP 5: Starting {KIMI_MODEL_NAME} Processing Loop ({total_tasks_to_process} Tasks)")
print("="*30)
# Use tqdm for the progress bar, iterating over the identified task indices
pbar = tqdm(enumerate(tasks_to_process_indices), total=total_tasks_to_process, desc=f"Processing {KIMI_MODEL_NAME} Tasks")
for loop_idx, (row_idx, slot_i) in pbar:
    # Get the row data *from our mutable list* using the original index
    row = updated_data[row_idx]  # This is already a dictionary
    # Set description in tqdm dynamically
    pbar.set_description(f"Processing Row {row_idx}, Slot {slot_i}")
    prompt_text_key = f"prompt_text_{slot_i}"
    response_text_key = f"response_text_{slot_i}"
    response_audio_key = f"response_audio_path_{slot_i}"
    model_key = f"model_{slot_i}"
    # --- Sanity Check --- (guards against stale/duplicate task indices)
    if row.get(model_key) != KIMI_MODEL_NAME:
        tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Model is '{row.get(model_key)}', not '{KIMI_MODEL_NAME}'.")
        continue
    if row.get(response_text_key):
        tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Already has response text '{str(row.get(response_text_key))[:50]}...'.")
        continue
    # --- Prepare Kimi Model Inputs ---
    prompt_text = row.get(prompt_text_key, "")
    question_audio_path = row.get('question_audio')
    metadata_str = row.get('metadata', "{}")
    source_dataset = row.get('source_dataset')
    # Check for essential input audio path validity; record an error marker so
    # this slot is not retried on resume.
    if not question_audio_path or not os.path.exists(question_audio_path):
        tqdm.write(f" Error: Input audio path missing or invalid for Row {row_idx}: '{question_audio_path}'. Skipping model call.")
        updated_data[row_idx][response_text_key] = "[ERROR: Missing Input Audio]"
        updated_data[row_idx][response_audio_key] = None
        continue  # Move to the next task in the loop
    # --- Construct Kimi `messages` list ---
    kimi_messages = []
    # 1. Parse History (if any) — only 'ultra' rows carry a history string.
    if source_dataset == 'ultra' and metadata_str:
        try:
            metadata = json.loads(metadata_str)
            history_str = metadata.get('history', '')
            if history_str:
                # Ensure history messages have 'message_type': 'text'
                history_messages_parsed = parse_ultra_history(history_str)
                kimi_messages.extend(history_messages_parsed)
        except json.JSONDecodeError:
            tqdm.write(f" Warning: Could not parse metadata JSON for row {row_idx}")
        except Exception as hist_e:
            tqdm.write(f" Warning: Error processing history for row {row_idx}: {hist_e}")
    # Add elif blocks here for history parsing from other datasets if needed
    # 2. Add Current User Turn (Text Prompt + Audio Path)
    # Add text prompt first, if it exists and is not empty
    if prompt_text and prompt_text.strip():
        kimi_messages.append({"role": "user", "message_type": "text", "content": prompt_text.strip()})
    # Add the user audio query using its path
    kimi_messages.append({"role": "user", "message_type": "audio", "content": question_audio_path})
    # Generate unique output audio filename (uuid avoids collisions across runs)
    unique_id = str(uuid.uuid4())
    output_audio_filename = f"{KIMI_MODEL_NAME}_row{row_idx}_slot{slot_i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
    output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
    # --- Call Kimi Model ---
    call_start_time = time.time()
    response_text, saved_audio_path = call_kimi_model(
        model,
        kimi_messages,
        KIMI_SAMPLING_PARAMS,
        output_audio_filepath,
        OUTPUT_AUDIO_SAMPLERATE
    )
    call_end_time = time.time()
    audio_basename = os.path.basename(str(saved_audio_path)) if saved_audio_path else "None"
    tqdm.write(f" Row {row_idx}, Slot {slot_i}: Finished in {call_end_time - call_start_time:.2f}s. Text: '{str(response_text)[:50]}...', Audio: {audio_basename}")
    # Store results back into the main data list (updated_data)
    updated_data[row_idx][response_text_key] = response_text  # Store text/error marker
    updated_data[row_idx][response_audio_key] = saved_audio_path  # Store path or None
    # Count success when the returned text is not an "[ERROR..." marker.
    # NOTE: audio-generation failure with valid text still counts as a success here.
    if response_text is not None and not response_text.startswith("[ERROR"):
        tasks_processed_count += 1
    # --- Periodic Saving --- (every N tasks, and always after the last one)
    processed_count_in_loop = loop_idx + 1
    if processed_count_in_loop % SAVE_EVERY_N_SAMPLES == 0 or processed_count_in_loop == total_tasks_to_process:
        save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
# --- STEP 6: Final Summary and Save ---
end_total_time = time.time()
print("\n" + "="*30)
print(f"STEP 6: {KIMI_MODEL_NAME} Processing Complete - Summary")
print("="*30)
print(f"Total tasks identified for processing in this run: {total_tasks_to_process}")
print(f"Total tasks successfully processed (generated text): {tasks_processed_count}")  # Update definition if needed
total_duration = end_total_time - start_total_time
print(f"Total processing time for this run: {format_time(total_duration)}")
if tasks_processed_count > 0:
    avg_time = total_duration / tasks_processed_count
    print(f"Average time per successfully processed task in this run: {avg_time:.2f} seconds")
else:
    print("Average time per task: N/A (no tasks successfully processed in this run)")
# --- Final Save ---
# Always persist the final state, even if a periodic checkpoint just ran.
print("\nPerforming final save of the dataset...")
save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
print("\nScript finished.")