# =============================================================================
# Generate Kimi-Audio responses for unfilled slots in a preference-sampling
# dataset. Resumable: reloads a previously-saved output dataset if present,
# processes only slots assigned to KIMI_MODEL_NAME that have no response yet,
# saves generated audio to disk, and checkpoints the dataset periodically.
# =============================================================================

import os
import json
import time
import re          # For parsing history
import uuid        # For generating unique filenames
import sys
import datetime    # For ETA formatting
import traceback   # For detailed error printing

import numpy as np   # FIX: required by call_kimi_model's dtype handling (np.issubdtype); was missing
import torch         # Kimi might return tensors
import soundfile as sf  # For saving Kimi audio output
from datasets import load_from_disk, Dataset, Features, Audio, Value
from dotenv import load_dotenv
from tqdm import tqdm  # Import tqdm

# --- Kimi-Audio Project Path Setup ---
# <--- *** IMPORTANT: Update this path to the PARENT directory containing the
#      'kimia_infer' folder *** --->
kimia_project_parent_dir = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio"

# Check if the path exists and add it to sys.path
if os.path.isdir(kimia_project_parent_dir):
    if kimia_project_parent_dir not in sys.path:
        sys.path.insert(0, kimia_project_parent_dir)
        print(f"Added '{kimia_project_parent_dir}' to Python path.")
    # Try importing KimiAudio only after potentially adding the path
    try:
        from kimia_infer.api.kimia import KimiAudio  # Kimi model class
    except ImportError as import_err:
        print(f"Error: Could not import KimiAudio from '{kimia_project_parent_dir}'.")
        print(f"ImportError: {import_err}")
        print("Please ensure the 'kimia_infer' directory exists within the specified path and check dependencies.")
        exit(1)
else:
    print(f"Error: Kimi project parent directory not found: '{kimia_project_parent_dir}'")
    print("Please update the 'kimia_project_parent_dir' variable in the script.")
    exit(1)

# --- Configuration ---
load_dotenv()  # Load environment variables if needed (e.g., API keys, though not typical for local Kimi)

# 1. Model & Tokenizer Setup (Kimi Specific)
KIMI_MODEL_NAME = "kimi_audio"  # Identifier used in your dataset's model_N columns
KIMI_MODEL_PATH = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio/checkpoint/Kimi-Audio-7B-Instruct"  # Path to your Kimi model checkpoint directory
# KIMI_DEVICE = 'cuda'          # KimiAudio class likely handles device selection based on availability.
# KIMI_DTYPE = torch.bfloat16   # KimiAudio likely handles dtype internally.

# 2. Dataset Paths
INPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_sampling_tasks"   # Original source
OUTPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_tasks_with_kimi"  # Where Kimi processed data is saved/resumed from

# 3. Output Audio Configuration (Kimi Specific)
OUTPUT_AUDIO_ROOT_DIR = "/home/chenyifu/audio-r1/r1-a/generated_audio/kimi"  # Where Kimi generated audio files are saved
OUTPUT_AUDIO_FORMAT = "wav"
OUTPUT_AUDIO_SAMPLERATE = 24000  # Kimi example uses 24kHz output. Confirm this matches your model's native output SR.

# 4. Kimi Call Settings (Based on example, adjust as needed)
KIMI_SAMPLING_PARAMS = {
    "audio_temperature": 0.8,
    "audio_top_k": 10,
    "text_temperature": 0.0,  # 0.0 for deterministic text, increase for more variety
    "text_top_k": 5,          # Relevant if text_temperature > 0
    "audio_repetition_penalty": 1.0,
    "audio_repetition_window_size": 64,
    "text_repetition_penalty": 1.0,
    "text_repetition_window_size": 16,
    # "max_new_tokens": 128   # Add if needed and supported by KimiAudio.generate
}
KIMI_OUTPUT_TYPE = "both"  # Generate both audio and text

# 5. Periodic Save Settings
SAVE_EVERY_N_SAMPLES = 50  # Save after processing this many samples


# --- Helper Functions ---

def format_time(seconds):
    """Formats seconds into a human-readable string H:MM:SS."""
    if seconds < 0:
        return "N/A"
    return str(datetime.timedelta(seconds=int(seconds)))


# REMOVED load_audio_minicpm - Kimi takes the path directly

def parse_ultra_history(history_str):
    """Parses the specific history string format from ultra metadata for Kimi.

    Returns a list of {"role", "message_type", "content"} dicts
    (message_type is always "text" — Kimi requires it on history turns).
    """
    messages = []
    # Relaxed pattern to capture content even if tags are slightly off or whitespace varies
    pattern = re.compile(r"\[\s*(USER|ASSISTANT)\s*\]\s*([\s\S]*?)(?=\s*\[\s*(?:USER|ASSISTANT)\s*\]|$)")
    matches = pattern.findall(history_str)

    if not matches and history_str and history_str.strip():
        # Simple fallback if standard pattern fails but there's content
        if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
            role = "user"
            content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
            if content:
                messages.append({"role": role, "message_type": "text", "content": content})  # Add Kimi message_type
        elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
            role = "assistant"
            content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
            if content:
                messages.append({"role": role, "message_type": "text", "content": content})  # Add Kimi message_type
        else:
            print(f"Warning: Could not parse history string format: {history_str[:100]}...")
        return messages  # Return whatever was parsed, even if empty

    for role_tag, content in matches:
        role = role_tag.strip().lower()
        cleaned_content = content.strip()
        if cleaned_content:
            # IMPORTANT: Add message_type='text' for Kimi history
            messages.append({"role": role, "message_type": "text", "content": cleaned_content})
    return messages


# --- Kimi Model Interaction Function ---

def call_kimi_model(model, messages_input, sampling_params, output_audio_filepath, output_sample_rate):
    """Calls the Kimi-Audio model, saves audio, returns text and audio path.

    Returns (response_text, saved_audio_path): response_text is the generated
    text (or an "[ERROR: ...]" / "[Audio Generated, No Text Output]" marker),
    saved_audio_path is the written file's path or None when audio failed.
    """
    try:
        # 1. Ensure Output Directory Exists
        os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)

        # 2. Call Kimi's Generate Function
        wav_output, text_output = model.generate(
            messages_input,
            **sampling_params,
            output_type=KIMI_OUTPUT_TYPE  # Use 'both'
        )

        # 3. Process and Save Audio Output
        saved_audio_path = None
        if wav_output is not None and isinstance(wav_output, torch.Tensor) and wav_output.numel() > 0:  # Check if tensor is not empty
            try:
                # Ensure tensor is on CPU, flatten, convert to numpy.
                # NOTE(review): assumes Kimi's wav output is safe to flatten to 1-D — confirm output shape.
                audio_data = wav_output.detach().cpu().view(-1).numpy()
                # Ensure data is float32 as supported by soundfile/WAV.
                # FIX: np.* calls below previously raised NameError (numpy was never imported);
                # numpy is now imported at module top.
                if audio_data.dtype != 'float32':
                    if np.issubdtype(audio_data.dtype, np.integer):
                        # NOTE(review): integer PCM may need scaling to [-1, 1]
                        # (e.g. int16 -> / 32768.0); the plain cast preserves the
                        # original behavior — confirm against the model's actual range.
                        audio_data = audio_data.astype(np.float32)
                    else:
                        audio_data = audio_data.astype(np.float32)

                sf.write(output_audio_filepath, audio_data, output_sample_rate)

                # Check if file was actually created and has size
                if os.path.exists(output_audio_filepath) and os.path.getsize(output_audio_filepath) > 100:  # Reasonable size threshold
                    saved_audio_path = output_audio_filepath
                else:
                    print(f" Error: Kimi generate finished but output audio file seems empty or too small at {output_audio_filepath}")
                    if os.path.exists(output_audio_filepath):
                        try:
                            os.remove(output_audio_filepath)
                        except OSError as rm_err:
                            print(f" Warning: Could not remove empty/small file {output_audio_filepath}: {rm_err}")
            except Exception as sf_err:
                print(f" Error saving Kimi audio output to {output_audio_filepath}: {sf_err}")
                traceback.print_exc()
                if os.path.exists(output_audio_filepath):
                    try:
                        os.remove(output_audio_filepath)
                    except OSError as rm_err:
                        print(f" Warning: Could not remove potentially corrupt file {output_audio_filepath}: {rm_err}")
        elif wav_output is None:
            print(" Warning: Kimi model did not return an audio tensor (wav_output is None).")
        elif isinstance(wav_output, torch.Tensor) and wav_output.numel() == 0:
            print(" Warning: Kimi model returned an empty audio tensor.")
        else:
            print(f" Warning: Kimi model returned unexpected audio output type: {type(wav_output)}. Expected torch.Tensor.")

        # 4. Process Text Output
        response_text_cleaned = ""
        if isinstance(text_output, str):
            response_text_cleaned = text_output.strip()
        elif text_output is not None:
            response_text_cleaned = str(text_output).strip()  # Convert just in case
        else:
            # If text is None but audio might exist, use a specific marker
            if saved_audio_path:
                response_text_cleaned = "[Audio Generated, No Text Output]"
            else:
                response_text_cleaned = "[ERROR: No Text Output]"

        # Return text (even if audio failed) and the path (or None)
        return response_text_cleaned, saved_audio_path

    except Exception as e:
        print(f"\n --- Error during Kimi model call ---")
        # Avoid printing potentially huge message list directly
        first_message = messages_input[0] if messages_input else "N/A"
        last_message_content = messages_input[-1]['content'] if messages_input else "N/A"
        if isinstance(last_message_content, str) and len(last_message_content) > 100:
            last_message_preview = last_message_content[:100] + "..."
        else:
            last_message_preview = last_message_content
        print(f" Input Messages Info: Count={len(messages_input)}, First={first_message}, Last Content Preview='{last_message_preview}'")
        print(f" Exception Type: {type(e).__name__}")
        print(f" Error Details: {e}")
        print(" Traceback:")
        traceback.print_exc()
        print(" --- End Error Details ---")
        # Attempt cleanup of potentially incomplete output file
        if 'output_audio_filepath' in locals() and os.path.exists(output_audio_filepath):
            try:
                os.remove(output_audio_filepath)
            except OSError as rm_err:
                print(f" Warning: Could not remove file {output_audio_filepath} after error: {rm_err}")
        # Return clear error markers
        return "[ERROR: Kimi Model Call Failed]", None


# --- Dataset Saving Function (Modified for Kimi context) ---

def save_checkpoint(data_list, features, output_dir, fallback_dir=None):
    """Saves the current state of the data list as a Hugging Face Dataset.

    On save_to_disk failure, optionally writes a JSON Lines fallback into
    fallback_dir, coercing non-serializable values to safe representations.
    """
    if not data_list:
        print("\nSkipping checkpoint save: data list is empty.")
        return

    print(f"\nSaving checkpoint with {len(data_list)} rows to {output_dir}...")
    try:
        # Ensure the list contains dictionaries
        data_to_save = [dict(item) for item in data_list]

        # --- Feature Check/Adaptation (Optional but recommended) ---
        # Sometimes saving fails if data types changed unexpectedly (e.g., None -> str)
        # It's safer to create the Dataset *without* features first, then cast
        temp_dataset = Dataset.from_list(data_to_save)
        # Now cast to the original features, allowing potential None/type mismatches
        # This might raise warnings but is often more robust than direct from_list with features
        updated_dataset = temp_dataset.cast(features)
        # --- End Feature Check ---

        # Ensure output directory exists before saving
        os.makedirs(output_dir, exist_ok=True)
        updated_dataset.save_to_disk(output_dir)
        print("Checkpoint saved successfully.")
    except Exception as e:
        print(f"Error saving checkpoint dataset using save_to_disk to {output_dir}: {e}")
        traceback.print_exc()
        if fallback_dir:
            # Use Kimi-specific name in fallback path
            fallback_path = os.path.join(fallback_dir, f"updated_{KIMI_MODEL_NAME}_data_checkpoint_{int(time.time())}.jsonl")
            print(f"Attempting to save data as JSON Lines fallback to: {fallback_path}")
            try:
                os.makedirs(fallback_dir, exist_ok=True)
                with open(fallback_path, 'w', encoding='utf-8') as f:
                    # Reuse data_to_save which is already list of dicts
                    for item in data_to_save:
                        # Ensure all values are serializable
                        serializable_item = {}
                        for k, v in item.items():
                            if isinstance(v, (datetime.datetime, datetime.date)):
                                serializable_item[k] = v.isoformat()
                            elif isinstance(v, bytes):
                                serializable_item[k] = v.decode('utf-8', errors='ignore')
                            elif isinstance(v, torch.Tensor):
                                # Handle potential tensors if not caught earlier
                                print(f" Warning: Found unexpected Tensor for key '{k}' in fallback save. Converting to list.")
                                serializable_item[k] = v.tolist()
                            elif not isinstance(v, (str, int, float, bool, list, dict, type(None))):
                                print(f" Warning: Converting non-standard type {type(v)} for key '{k}' to string for JSON fallback.")
                                serializable_item[k] = str(v)
                            else:
                                serializable_item[k] = v
                        try:
                            f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
                        except TypeError as json_type_err:
                            print(f" Skipping row due to JSON serialization error: {json_type_err} in item part: {k}={v}")
                print("Fallback JSON Lines checkpoint saved successfully.")
            except Exception as json_e:
                print(f"Error saving fallback JSON Lines checkpoint: {json_e}")


# =============================================
# --- Main Processing Logic ---
# =============================================

# --- STEP 1: Dataset Loading (Modified for Resumption) ---
print("="*30)
print("STEP 1: Loading Dataset")
print("="*30)

dataset = None
original_features = None  # Initialize

# Check if the Kimi-specific output directory exists (resume path)
if os.path.exists(OUTPUT_DATASET_DIR):
    print(f"Found existing Kimi processed dataset directory at: {OUTPUT_DATASET_DIR}")
    print("Attempting to load it to resume processing...")
    try:
        dataset = load_from_disk(OUTPUT_DATASET_DIR)
        original_features = dataset.features  # Get features from the loaded dataset
        print(f"Resumed Kimi dataset loaded successfully with {len(dataset)} rows.")
        print(f"Features from resumed dataset: {original_features}")
    except Exception as e:
        print(f"Warning: Error loading existing Kimi dataset from {OUTPUT_DATASET_DIR}: {e}")
        traceback.print_exc()
        print("Will attempt to load the original input dataset instead.")
        dataset = None  # Reset dataset variable
else:
    print(f"No existing Kimi processed dataset found at {OUTPUT_DATASET_DIR}.")
    print("Will attempt to load the original input dataset.")

# If dataset is still None, load from the original input directory
if dataset is None:
    print(f"\nLoading original input dataset from: {INPUT_DATASET_DIR}")
    if not os.path.exists(INPUT_DATASET_DIR):
        print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
        exit(1)
    try:
        dataset = load_from_disk(INPUT_DATASET_DIR)
        original_features = dataset.features  # Get features from the input dataset
        print(f"Original input dataset loaded successfully with {len(dataset)} rows.")
        print(f"Features from input dataset: {original_features}")
    except Exception as e:
        print(f"FATAL: Error loading original input dataset from {INPUT_DATASET_DIR}: {e}")
        traceback.print_exc()
        exit(1)

# --- Ensure dataset and features were loaded ---
if dataset is None or original_features is None:
    print("FATAL: Failed to load any dataset. Exiting.")
    exit(1)
# --- End Dataset Loading ---


# --- STEP 2: Pre-computation - Identify Kimi Tasks ---
print("\n" + "="*30)
print(f"STEP 2: Identifying '{KIMI_MODEL_NAME}' Tasks to Process")
print("="*30)

pkusafe_tasks_indices = []
other_tasks_indices = []

# Iterate through the loaded dataset structure
for idx, row in enumerate(dataset):
    source_dataset = row.get('source_dataset')
    processed_in_row = False  # Flag to ensure we only pick one Kimi slot per row
    for i in range(1, 4):  # Check slots 1, 2, 3
        model_key = f"model_{i}"
        response_text_key = f"response_text_{i}"
        # Check if the slot is assigned to Kimi and is NOT yet filled (text response missing)
        is_target_model_task = row.get(model_key) == KIMI_MODEL_NAME
        is_unfilled = not row.get(response_text_key)  # True if None or empty string
        if is_target_model_task and is_unfilled and not processed_in_row:
            task_info = (idx, i)  # Store tuple of (original_row_index, slot_index)
            if source_dataset == 'pkusafe':
                pkusafe_tasks_indices.append(task_info)
            else:
                other_tasks_indices.append(task_info)
            processed_in_row = True  # Mark row as having a task identified

# Combine lists, prioritizing pkusafe
tasks_to_process_indices = pkusafe_tasks_indices + other_tasks_indices
total_tasks_to_process = len(tasks_to_process_indices)

print(f"Found {len(pkusafe_tasks_indices)} 'pkusafe' tasks and {len(other_tasks_indices)} other tasks requiring '{KIMI_MODEL_NAME}' processing in the loaded dataset.")
print(f"Total tasks remaining to process: {total_tasks_to_process}")

if total_tasks_to_process == 0:
    print(f"\nNo remaining tasks to process for {KIMI_MODEL_NAME} based on the loaded dataset.")
    # Optionally, perform a final save for consistency
    # print("Performing a final save to ensure consistency...")
    # final_data_list = [dict(row) for row in dataset]
    # fallback_save_dir_final = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
    # save_checkpoint(final_data_list, original_features, OUTPUT_DATASET_DIR, fallback_save_dir_final)
    print("Exiting.")
    exit(0)
# --- End Pre-computation Step ---


# --- STEP 3: Load Kimi Model ---
print("\n" + "="*30)
print(f"STEP 3: Loading {KIMI_MODEL_NAME} Model")
print("="*30)
try:
    # Load Kimi model using the class imported earlier
    model = KimiAudio(model_path=KIMI_MODEL_PATH, load_detokenizer=True)  # Detokenizer needed per Kimi example
    print(f"{KIMI_MODEL_NAME} model loaded successfully from {KIMI_MODEL_PATH}.")
except NameError:
    print("FATAL: KimiAudio class not defined. Import likely failed earlier.")
    exit(1)
except Exception as e:
    print(f"Error loading {KIMI_MODEL_NAME} model from {KIMI_MODEL_PATH}: {e}")
    traceback.print_exc()
    exit(1)


# --- STEP 4: Prepare for Processing ---
print("\n" + "="*30)
print(f"STEP 4: Preparing for {KIMI_MODEL_NAME} Processing")
print("="*30)

# Create output directories if they don't exist
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DATASET_DIR, exist_ok=True)
# Define and create fallback directory for Kimi
fallback_save_dir = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
os.makedirs(fallback_save_dir, exist_ok=True)

print(f"Audio outputs will be saved in: {OUTPUT_AUDIO_ROOT_DIR}")
print(f"Dataset checkpoints will be saved in: {OUTPUT_DATASET_DIR}")
print(f"Fallback checkpoints (JSONL) in: {fallback_save_dir}")

# Create a mutable list of dictionaries from the loaded dataset for updates
updated_data = [dict(row) for row in dataset]  # Convert each row to a dictionary

tasks_processed_count = 0  # Count successful completions for average time calculation
start_total_time = time.time()


# --- STEP 5: Start Processing Loop ---
print("\n" + "="*30)
print(f"STEP 5: Starting {KIMI_MODEL_NAME} Processing Loop ({total_tasks_to_process} Tasks)")
print("="*30)

# Use tqdm for the progress bar, iterating over the identified task indices
pbar = tqdm(enumerate(tasks_to_process_indices), total=total_tasks_to_process, desc=f"Processing {KIMI_MODEL_NAME} Tasks")

for loop_idx, (row_idx, slot_i) in pbar:
    # Get the row data *from our mutable list* using the original index
    row = updated_data[row_idx]  # This is already a dictionary

    # Set description in tqdm dynamically
    pbar.set_description(f"Processing Row {row_idx}, Slot {slot_i}")

    prompt_text_key = f"prompt_text_{slot_i}"
    response_text_key = f"response_text_{slot_i}"
    response_audio_key = f"response_audio_path_{slot_i}"
    model_key = f"model_{slot_i}"

    # --- Sanity Check ---
    if row.get(model_key) != KIMI_MODEL_NAME:
        tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Model is '{row.get(model_key)}', not '{KIMI_MODEL_NAME}'.")
        continue
    if row.get(response_text_key):
        tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Already has response text '{str(row.get(response_text_key))[:50]}...'.")
        continue

    # --- Prepare Kimi Model Inputs ---
    prompt_text = row.get(prompt_text_key, "")
    question_audio_path = row.get('question_audio')
    metadata_str = row.get('metadata', "{}")
    source_dataset = row.get('source_dataset')

    # Check for essential input audio path validity
    if not question_audio_path or not os.path.exists(question_audio_path):
        tqdm.write(f" Error: Input audio path missing or invalid for Row {row_idx}: '{question_audio_path}'. Skipping model call.")
        updated_data[row_idx][response_text_key] = "[ERROR: Missing Input Audio]"
        updated_data[row_idx][response_audio_key] = None
        continue  # Move to the next task in the loop

    # --- Construct Kimi `messages` list ---
    kimi_messages = []

    # 1. Parse History (if any)
    if source_dataset == 'ultra' and metadata_str:
        try:
            metadata = json.loads(metadata_str)
            history_str = metadata.get('history', '')
            if history_str:
                # Ensure history messages have 'message_type': 'text'
                history_messages_parsed = parse_ultra_history(history_str)
                kimi_messages.extend(history_messages_parsed)
        except json.JSONDecodeError:
            tqdm.write(f" Warning: Could not parse metadata JSON for row {row_idx}")
        except Exception as hist_e:
            tqdm.write(f" Warning: Error processing history for row {row_idx}: {hist_e}")
    # Add elif blocks here for history parsing from other datasets if needed

    # 2. Add Current User Turn (Text Prompt + Audio Path)
    # Add text prompt first, if it exists and is not empty
    if prompt_text and prompt_text.strip():
        kimi_messages.append({"role": "user", "message_type": "text", "content": prompt_text.strip()})
    # Add the user audio query using its path
    kimi_messages.append({"role": "user", "message_type": "audio", "content": question_audio_path})

    # Generate unique output audio filename
    unique_id = str(uuid.uuid4())
    output_audio_filename = f"{KIMI_MODEL_NAME}_row{row_idx}_slot{slot_i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
    output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)

    # --- Call Kimi Model ---
    call_start_time = time.time()
    response_text, saved_audio_path = call_kimi_model(
        model, kimi_messages, KIMI_SAMPLING_PARAMS,
        output_audio_filepath, OUTPUT_AUDIO_SAMPLERATE
    )
    call_end_time = time.time()

    audio_basename = os.path.basename(str(saved_audio_path)) if saved_audio_path else "None"
    tqdm.write(f" Row {row_idx}, Slot {slot_i}: Finished in {call_end_time - call_start_time:.2f}s. Text: '{str(response_text)[:50]}...', Audio: {audio_basename}")

    # Store results back into the main data list (updated_data)
    updated_data[row_idx][response_text_key] = response_text       # Store text/error marker
    updated_data[row_idx][response_audio_key] = saved_audio_path   # Store path or None

    # Increment success counter based on successful generation (e.g., text isn't an error marker)
    # Consider if audio generation failure should also mark task as failed.
    # Current logic counts success if text seems okay.
    if response_text is not None and not response_text.startswith("[ERROR"):
        tasks_processed_count += 1

    # --- Periodic Saving ---
    processed_count_in_loop = loop_idx + 1
    if processed_count_in_loop % SAVE_EVERY_N_SAMPLES == 0 or processed_count_in_loop == total_tasks_to_process:
        save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)


# --- STEP 6: Final Summary and Save ---
end_total_time = time.time()
print("\n" + "="*30)
print(f"STEP 6: {KIMI_MODEL_NAME} Processing Complete - Summary")
print("="*30)
print(f"Total tasks identified for processing in this run: {total_tasks_to_process}")
print(f"Total tasks successfully processed (generated text): {tasks_processed_count}")  # Update definition if needed
total_duration = end_total_time - start_total_time
print(f"Total processing time for this run: {format_time(total_duration)}")
if tasks_processed_count > 0:
    avg_time = total_duration / tasks_processed_count
    print(f"Average time per successfully processed task in this run: {avg_time:.2f} seconds")
else:
    print("Average time per task: N/A (no tasks successfully processed in this run)")

# --- Final Save ---
print("\nPerforming final save of the dataset...")
save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)

print("\nScript finished.")