"""Compute intra-task semantic similarity for fine-tuning question datasets.

For every fine-tuning task, the script embeds the (optionally
instruction-stripped) questions with a Sentence Transformer, then repeatedly
splits each task's embeddings into two random halves and measures the cosine
similarity between the half-averages. High scores indicate tasks whose
questions are semantically homogeneous.
"""

import argparse
import json
import os
import random
import re
import sys

import numpy as np
from sentence_transformers import SentenceTransformer  # pip install sentence-transformers
from sklearn.metrics.pairwise import cosine_similarity  # pip install scikit-learn

# --- 1. Configuration Settings (from original script) ---

# TODO: Fill in the directory containing the 8 metadata JSON files
METADATA_BASE_DIR = 'metadata'

# TODO: Fill in the directory where the report JSON should be written
OUTPUT_DIR = "intra_task"

# --- Hardcoded list of metadata filenames ---
# Uncomment the files you want to include in the analysis.
METADATA_FILENAMES = [
    # '2d_Spatial457_SAT_SPAR-7M_80k_metadata.json',
    # '3d_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k_metadata.json',
    # 'dynamic_Spatial457_SAT_SPAR-7M_80k_metadata.json',
    # 'perception_Spatial457_SAT_SPAR-7M_80k_metadata.json',
    # 'real_SPAR-7M_RoboSpatial_80.0k_metadata.json',
    # 'reasoning_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k_metadata.json',
    # 'static_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k_metadata.json',
    # 'synthetic_SAT_Spatial457_PRISM_80.0k_metadata.json'
]

# --- FT files (hardcoded list of FT data) ---
ADDITIONAL_FT_FILES = {
    # RefSpatial
    "RefSpatial_2D_choice_qa": "../RefSpatial_data/2D/choice_qa.json",
    "RefSpatial_2D_reasoning_template_qa": "../RefSpatial_data/2D/reasoning_template_qa.json",
    "RefSpatial_3D_choice_qa": "../RefSpatial_data/3D/choice_qa.json",
    "RefSpatial_3D_multi_view_qa": "../RefSpatial_data/3D/multi_view_qa.json",
    "RefSpatial_3D_reasoning_template_qa": "../RefSpatial_data/3D/reasoning_template_qa.json",
    "RefSpatial_3D_vacant_qa": "../RefSpatial_data/3D/vacant_qa.json",
    "RefSpatial_3D_visual_choice_qa": "../RefSpatial_data/3D/visual_choice_qa.json",
    # Spatial457
    "Spatial457_L1_single": "../Spatial457_data/qwen_data_new/L1_single.json",
    "Spatial457_L2_objects": "../Spatial457_data/qwen_data_new/L2_objects.json",
    "Spatial457_L3_2d_spatial": "../Spatial457_data/qwen_data_new/L3_2d_spatial.json",
    "Spatial457_L4_occ": "../Spatial457_data/qwen_data_new/L4_occ.json",
    "Spatial457_L4_pose": "../Spatial457_data/qwen_data_new/L4_pose.json",
    "Spatial457_L5_6d_spatial": "../Spatial457_data/qwen_data_new/L5_6d_spatial.json",
    "Spatial457_L5_collision": "../Spatial457_data/qwen_data_new/L5_collision.json",
    # SPAR-7M
    "SPAR-7M_obj_count": "../SPAR-7M_data/qwen_data/obj_count.json",
    "SPAR-7M_obj_spatial_relation": "../SPAR-7M_data/qwen_data/obj_spatial_relation.json",
    "SPAR-7M_spatial_imagination": "../SPAR-7M_data/qwen_data/spatial_imagination.json",
    # SAT
    "SAT_action_consequence": "../SAT_data/qwen_data_new/action_consequence.json",
    "SAT_action_sequence": "../SAT_data/qwen_data_new/action_sequence.json",
    "SAT_goal_aim": "../SAT_data/qwen_data_new/goal_aim.json",
    "SAT_obj_movement": "../SAT_data/qwen_data_new/obj_movement.json",
    "SAT_other": "../SAT_data/qwen_data_new/other.json",
    "SAT_perspective": "../SAT_data/qwen_data_new/perspective.json",
    # PRISM
    "PRISM_train_data": "../PRISM_data/qwen_data/train_data.json",
    # RoboSpatial
    "RoboSpatial_compatibility": "../RoboSpatial_data/qwen_data/robospatial_compatibility.json",
    "RoboSpatial_configuration": "../RoboSpatial_data/qwen_data/robospatial_configuration.json",
    "RoboSpatial_context": "../RoboSpatial_data/qwen_data/robospatial_context.json",
}

# --- Analysis Configuration ---
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
DEFAULT_NUM_ITERATIONS = 20
MAX_SAMPLES_PER_ITERATION = 10000
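
# NOTE: The metadata layout sketched below is inferred from the loading logic
# in analyze_ft_intra_similarity(); actual files may carry additional keys.
#
#   {
#     "source_files": [
#       {"file_path": "SAT_data/qwen_data_new/perspective.json"},
#       ...
#     ]
#   }
#
# Each referenced data file is a JSON list whose items provide the question
# text either as item["conversations"][0]["value"] or as item["question"].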

# --- 2. Helper Function (from original script) ---
def clean_question(text, keep_instructions=False):
    """Cleans question text, optionally removing instructions."""
    if not isinstance(text, str):
        return ""
    cleaned_text = re.sub(r'^\n?', '', text).strip()

    # If we keep instructions, just do the basic clean and return
    if keep_instructions:
        return cleaned_text

    # If we remove instructions, apply all regex rules
    cleaned_text = re.sub(r'Choices:.*?(?=\nPlease answer|\n|$)', '', cleaned_text,
                          flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\n\nChoices:\s*.*?(?=\n|$)', '', cleaned_text,
                          flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\nPlease answer directly.*', '', cleaned_text,
                          flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'The options describe the spatial relationship.*?(\nChoose|\nPlease select|\nPick|\nSelect)',
                          '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'(\nChoose|\nPlease select|\nPick|\nSelect)\s+the (correct|right|appropriate) (response|option|answer).*?Your answer can only include.*?$',
                          '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\.?\s*Final answer should be.*?$', '', cleaned_text,
                          flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\.?\s*Your final answer should be formatted.*?points\.', '', cleaned_text,
                          flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\.?\s*Please use the world coordinate system.*?objects\.', '', cleaned_text,
                          flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = cleaned_text.rstrip('?.')
    return cleaned_text.strip()
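
# A minimal illustration of clean_question (hypothetical input; the exact
# behavior depends on which prompt template a dataset uses):
#
#   >>> clean_question("Where is the mug?\nPlease answer directly with a single word.")
#   'Where is the mug'
#
# The instruction tail is removed by the 'Please answer directly' rule, and the
# trailing '?' is dropped by rstrip('?.').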

# --- 3. Main Analysis Function ---
def analyze_ft_intra_similarity(metadata_dir, metadata_files, additional_files,
                                model_name, num_iterations, keep_instructions):
    """
    Calculates intra-task similarity for all specified fine-tuning datasets.
    """
    print(f"Using Sentence Transformer model: {model_name}")
    print(f"Keep instructions: {keep_instructions}")
    print(f"Running {num_iterations} iterations...\n")

    try:
        model = SentenceTransformer(model_name)
    except Exception as e:
        print(f"!!! ERROR loading model '{model_name}': {e}")
        return None, None, None

    # --- Load FT Data ---
    print("Processing Fine-tuning source data...")
    ft_questions = {}
    metadata_task_names = []
    additional_task_names = []

    # 1. Load from metadata
    print(f"Reading {len(metadata_files)} metadata files from '{metadata_dir}'...")
    for meta_filename in metadata_files:
        meta_file_path = os.path.join(metadata_dir, meta_filename)
        if not os.path.exists(meta_file_path):
            print(f"Warning: Meta file not found: {meta_file_path}")
            continue
        try:
            with open(meta_file_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            if 'source_files' in metadata:
                for src in metadata['source_files']:
                    task_name = src.get('file_path')  # Use relative path as task name
                    if not task_name:
                        continue
                    full_path = os.path.normpath(os.path.join("..", task_name))
                    metadata_task_names.append(task_name)
                    if task_name not in ft_questions:
                        ft_questions[task_name] = []
                    if not os.path.exists(full_path):
                        print(f"Warning: Source file not found: {full_path}")
                        continue

                    # Load questions from this source file
                    with open(full_path, 'r', encoding='utf-8') as f_data:
                        data = json.load(f_data)
                    for item in data:
                        q = None
                        if 'conversations' in item and item['conversations']:
                            q = item['conversations'][0].get('value')
                        elif 'question' in item:
                            q = item.get('question')
                        cleaned_q = clean_question(q, keep_instructions=keep_instructions)
                        if cleaned_q:
                            ft_questions[task_name].append(cleaned_q)
        except Exception as e:
            print(f"Warning: Could not process meta file {meta_filename}: {e}")

    # 2. Add additional files
    print(f"Adding {len(additional_files)} additional FT files...")
    for task_name, relative_path in additional_files.items():
        full_path = os.path.normpath(relative_path)
        additional_task_names.append(task_name)
        if task_name not in ft_questions:
            ft_questions[task_name] = []
        if not os.path.exists(full_path):
            print(f"Warning: Additional file not found: {full_path}")
            continue
        try:
            # Load questions from this source file
            with open(full_path, 'r', encoding='utf-8') as f_data:
                data = json.load(f_data)
            for item in data:
                q = None
                if 'conversations' in item and item['conversations']:
                    q = item['conversations'][0].get('value')
                elif 'question' in item:
                    q = item.get('question')
                cleaned_q = clean_question(q, keep_instructions=keep_instructions)
                if cleaned_q:
                    ft_questions[task_name].append(cleaned_q)
        except Exception as e:
            print(f"Warning: Could not process additional file {full_path}: {e}")

    print(f"Loaded questions from {len(ft_questions)} unique FT tasks.")

    # --- Pre-calculate all embeddings ---
    print("\nCalculating all FT task embeddings (once)...")
    ft_embeddings = {}
    for q_type, questions in ft_questions.items():
        if len(questions) < 2:
            print(f"  - Skipping {q_type} (only {len(questions)} question(s))")
            continue
        print(f"  - Encoding {q_type} ({len(questions)} questions)")
        try:
            ft_embeddings[q_type] = model.encode(questions, show_progress_bar=True)
        except Exception as e:
            print(f"!!! ERROR encoding {q_type}: {e}")

    if not ft_embeddings:
        print("!!! ERROR: No tasks with sufficient questions to encode.")
        return None, None, None
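
    # Split-half estimate of intra-task coherence: each iteration draws a
    # random sample of a task's question embeddings, splits it into two
    # halves, and compares the cosine similarity of the half-centroids.
    # Scores near 1.0 mean the task's questions are phrased very uniformly.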
    print("\nStarting similarity iterations...")
    task_similarity_scores = {q_type: [] for q_type in ft_embeddings}

    for i in range(num_iterations):
        if (i + 1) % 10 == 0 or i == 0:
            print(f"  Iteration {i + 1}/{num_iterations}...")
        for q_type, embeddings_list in ft_embeddings.items():
            total_count = len(embeddings_list)
            k = min(total_count, MAX_SAMPLES_PER_ITERATION)
            sampled_indices = random.sample(range(total_count), k)
            mid_point = k // 2
            group_a_indices = sampled_indices[:mid_point]
            group_b_indices = sampled_indices[mid_point:]
            group_a_embeds = [embeddings_list[idx] for idx in group_a_indices]
            group_b_embeds = [embeddings_list[idx] for idx in group_b_indices]
            avg_a = np.mean(group_a_embeds, axis=0).reshape(1, -1)
            avg_b = np.mean(group_b_embeds, axis=0).reshape(1, -1)
            similarity = cosine_similarity(avg_a, avg_b)[0][0]
            # --- FIX: Cast numpy.float32 to a standard python float ---
            task_similarity_scores[q_type].append(float(similarity))

    print("Iterations complete.\n")

    # --- Calculate Final Statistics ---
    final_stats = {}
    all_scores = []
    for q_type, scores in task_similarity_scores.items():
        if scores:
            final_stats[q_type] = {
                'mean': np.mean(scores),
                'std': np.std(scores),
                'iterations': len(scores),
                'num_questions': len(ft_embeddings[q_type]),
            }
            all_scores.extend(scores)

    # Calculate overall average
    if all_scores:
        final_stats['--OVERALL--'] = {
            'mean': np.mean(all_scores),
            'std': np.std(all_scores),
            'iterations': num_iterations,
            'num_questions': 'N/A',
        }

    # Return stats and the lists of task names for grouped printing
    return final_stats, metadata_task_names, additional_task_names


# --- 4. Main Execution Block ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Calculate intra-task similarity for all Fine-Tuning source datasets.")
    parser.add_argument('--keep_instructions', action='store_true',
                        help="Keep instruction text in questions during embedding.")
    parser.add_argument('--num_iterations', type=int, default=DEFAULT_NUM_ITERATIONS,
                        help=f"Number of random splits to perform (default: {DEFAULT_NUM_ITERATIONS}).")
    args = parser.parse_args()

    if not METADATA_BASE_DIR:
        print("!!! ERROR: Please fill in the METADATA_BASE_DIR variable in the script.")
        sys.exit()

    stats, metadata_tasks, additional_tasks = analyze_ft_intra_similarity(
        metadata_dir=METADATA_BASE_DIR,
        metadata_files=METADATA_FILENAMES,
        additional_files=ADDITIONAL_FT_FILES,
        model_name=MODEL_NAME,
        num_iterations=args.num_iterations,
        keep_instructions=args.keep_instructions,
    )

    if stats:
        print("--- Fine-Tuning Intra-Task Similarity Report ---")
        print(f"(Based on {args.num_iterations} random split iterations)\n")

        valid_tasks = [task for task in stats if task != '--OVERALL--']
        if not valid_tasks:
            print("No valid tasks with scores found.")
            sys.exit()

        max_name_len = max(len(t) for t in valid_tasks)
        max_name_len = max(35, max_name_len)
        header = (f"{'Task Name':<{max_name_len}} | {'Questions':>10} | "
                  f"{'Mean Similarity':>18} | {'Std. Deviation':>18}")
        print(header)
        print("-" * len(header))

        # --- (Console Print) Group 1: Additional FT Tasks ---
        print("\n--- Additional FT Tasks (e.g., RefSpatial) ---")
        additional_tasks_found = 0
        for task in sorted(set(additional_tasks)):
            if task in stats:
                data = stats[task]
                print(f"{task:<{max_name_len}} | {data['num_questions']:>10} | "
                      f"{data['mean']:>18.4f} | {data['std']:>18.4f}")
                additional_tasks_found += 1
        if additional_tasks_found == 0:
            print("No valid results for Additional FT Tasks.")

        # --- (Console Print) Group 2: Metadata-Discovered Tasks ---
        print("\n--- Metadata-Discovered Tasks (e.g., VQA, SPAR) ---")
        metadata_tasks_found = 0
        for task in sorted(set(metadata_tasks)):
            if task in stats:
                data = stats[task]
                print(f"{task:<{max_name_len}} | {data['num_questions']:>10} | "
                      f"{data['mean']:>18.4f} | {data['std']:>18.4f}")
                metadata_tasks_found += 1
        if metadata_tasks_found == 0:
            print("No valid results for Metadata-Discovered Tasks.")

        # --- (Console Print) Overall Summary ---
        if '--OVERALL--' in stats:
            data = stats['--OVERALL--']
            print("\n" + "-" * len(header))
            print(f"{'--OVERALL--':<{max_name_len}} | {data['num_questions']:>10} | "
                  f"{data['mean']:>18.4f} | {data['std']:>18.4f}")

        try:
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            suffix = "_with_instructions" if args.keep_instructions else ""
            report_filename = f"fixed_ft_datasets_intra_task_report{suffix}.json"
            report_save_path = os.path.join(OUTPUT_DIR, report_filename)

            report_data = {
                "report_summary": {
                    "model_name": MODEL_NAME,
                    "num_iterations": args.num_iterations,
                    "keep_instructions": args.keep_instructions,
                },
                "overall_stats": stats.get('--OVERALL--', {}),
                "additional_ft_tasks": {},
                "metadata_discovered_tasks": {},
            }
            # 1. Additional FT Tasks (RefSpatial, etc.)
            for task in sorted(set(additional_tasks)):
                if task in stats:
                    report_data["additional_ft_tasks"][task] = stats[task]
            # 2. Metadata-Discovered Tasks (VQA, SPAR, etc.)
            for task in sorted(set(metadata_tasks)):
                if task in stats:
                    report_data["metadata_discovered_tasks"][task] = stats[task]

            with open(report_save_path, 'w', encoding='utf-8') as f:
                json.dump(report_data, f, indent=2)
            print(f"\nStructured summary report saved to {report_save_path}")
        except Exception as e:
            print(f"\n!!! Could not save final summary report to JSON: {e}")

    print("\n--- Analysis finished ---")
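
# Example invocations (the script filename below is hypothetical; use whatever
# this file is saved as):
#
#   python ft_intra_task_similarity.py --num_iterations 50
#   python ft_intra_task_similarity.py --keep_instructions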