import json
import os
import re
import numpy as np
import argparse
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import random
import sys


# --- Configuration ---

# Directory containing the metadata JSON files listed in METADATA_FILENAMES.
METADATA_BASE_DIR = 'metadata'

# Directory where the structured JSON report is written.
OUTPUT_DIR = "intra_task"

# Metadata JSON filenames to read from METADATA_BASE_DIR. Fill these in before
# running; each file should follow the "source_files" layout sketched below.
METADATA_FILENAMES = [
]
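
# Each metadata file is expected to hold a "source_files" list whose entries
# carry a "file_path" relative to the repository root; the loader below
# resolves these paths against "..". A hypothetical minimal example (only the
# fields this script reads are shown):
#
#   {
#     "source_files": [
#       {"file_path": "SPAR-7M_data/qwen_data/obj_count.json"}
#     ]
#   }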

# Additional fine-tuning files, keyed by task name, with paths relative to the
# directory this script is run from.
ADDITIONAL_FT_FILES = {
    # RefSpatial
    "RefSpatial_2D_choice_qa": "../RefSpatial_data/2D/choice_qa.json",
    "RefSpatial_2D_reasoning_template_qa": "../RefSpatial_data/2D/reasoning_template_qa.json",
    "RefSpatial_3D_choice_qa": "../RefSpatial_data/3D/choice_qa.json",
    "RefSpatial_3D_multi_view_qa": "../RefSpatial_data/3D/multi_view_qa.json",
    "RefSpatial_3D_reasoning_template_qa": "../RefSpatial_data/3D/reasoning_template_qa.json",
    "RefSpatial_3D_vacant_qa": "../RefSpatial_data/3D/vacant_qa.json",
    "RefSpatial_3D_visual_choice_qa": "../RefSpatial_data/3D/visual_choice_qa.json",

    # Spatial457
    "Spatial457_L1_single": "../Spatial457_data/qwen_data_new/L1_single.json",
    "Spatial457_L2_objects": "../Spatial457_data/qwen_data_new/L2_objects.json",
    "Spatial457_L3_2d_spatial": "../Spatial457_data/qwen_data_new/L3_2d_spatial.json",
    "Spatial457_L4_occ": "../Spatial457_data/qwen_data_new/L4_occ.json",
    "Spatial457_L4_pose": "../Spatial457_data/qwen_data_new/L4_pose.json",
    "Spatial457_L5_6d_spatial": "../Spatial457_data/qwen_data_new/L5_6d_spatial.json",
    "Spatial457_L5_collision": "../Spatial457_data/qwen_data_new/L5_collision.json",

    # SPAR-7M
    "SPAR-7M_obj_count": "../SPAR-7M_data/qwen_data/obj_count.json",
    "SPAR-7M_obj_spatial_relation": "../SPAR-7M_data/qwen_data/obj_spatial_relation.json",
    "SPAR-7M_spatial_imagination": "../SPAR-7M_data/qwen_data/spatial_imagination.json",

    # SAT
    "SAT_action_consequence": "../SAT_data/qwen_data_new/action_consequence.json",
    "SAT_action_sequence": "../SAT_data/qwen_data_new/action_sequence.json",
    "SAT_goal_aim": "../SAT_data/qwen_data_new/goal_aim.json",
    "SAT_obj_movement": "../SAT_data/qwen_data_new/obj_movement.json",
    "SAT_other": "../SAT_data/qwen_data_new/other.json",
    "SAT_perspective": "../SAT_data/qwen_data_new/perspective.json",

    # PRISM
    "PRISM_train_data": "../PRISM_data/qwen_data/train_data.json",

    # RoboSpatial
    "RoboSpatial_compatibility": "../RoboSpatial_data/qwen_data/robospatial_compatibility.json",
    "RoboSpatial_configuration": "../RoboSpatial_data/qwen_data/robospatial_configuration.json",
    "RoboSpatial_context": "../RoboSpatial_data/qwen_data/robospatial_context.json",
}
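
# The loaders below assume each data file is a JSON list of records in one of
# two shapes: a conversation-style record, where the first human turn carries
# the question, or a flat record with a "question" key. A hypothetical example
# of each (only the fields this script reads are shown):
#
#   {"conversations": [{"from": "human", "value": "<image>\nWhere is the cup?"}]}
#   {"question": "Where is the cup?"}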

# Embedding model and iteration settings.
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
DEFAULT_NUM_ITERATIONS = 20
MAX_SAMPLES_PER_ITERATION = 10000  # cap on embeddings sampled per task per iteration


def clean_question(text, keep_instructions=False):
    """Cleans question text, optionally removing instructions."""
    if not isinstance(text, str):
        return ""
    cleaned_text = re.sub(r'^<image>\n?', '', text).strip()

    if keep_instructions:
        return cleaned_text

    # Strip choice lists and answer-format boilerplate so that only the core
    # question is embedded.
    cleaned_text = re.sub(r'Choices:.*?(?=\nPlease answer|\n|$)', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\n\nChoices:\s*.*?(?=\n|$)', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\nPlease answer directly.*', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'The options describe the spatial relationship.*?(\nChoose|\nPlease select|\nPick|\nSelect)', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'(\nChoose|\nPlease select|\nPick|\nSelect)\s+the (correct|right|appropriate) (response|option|answer).*?Your answer can only include.*?$', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\.?\s*Final answer should be.*?$', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\.?\s*Your final answer should be formatted.*?points\.', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = re.sub(r'\.?\s*Please use the world coordinate system.*?objects\.', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
    cleaned_text = cleaned_text.rstrip('?.')
    return cleaned_text.strip()
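
# A quick illustration of clean_question on a made-up prompt (not taken from
# any of the datasets above): with keep_instructions=False,
#
#   clean_question("<image>\nWhere is the cup?\nPlease answer directly with a word.")
#
# drops the <image> tag and the answer-format instruction and returns
# "Where is the cup" (the trailing '?' is stripped so punctuation does not
# affect the embedding).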


def analyze_ft_intra_similarity(metadata_dir, metadata_files, additional_files, model_name, num_iterations, keep_instructions):
    """
    Calculates intra-task similarity for all specified fine-tuning datasets.
    """
    print(f"Using Sentence Transformer model: {model_name}")
    print(f"Keep instructions: {keep_instructions}")
    print(f"Running {num_iterations} iterations...\n")

    try:
        model = SentenceTransformer(model_name)
    except Exception as e:
        print(f"!!! ERROR loading model '{model_name}': {e}")
        return None, None, None

    # --- Load questions from every FT source ---
    print("Processing Fine-tuning source data...")
    ft_questions = {}
    metadata_task_names = []
    additional_task_names = []

    # 1) Tasks discovered via the metadata files.
    print(f"Reading {len(metadata_files)} metadata files from '{metadata_dir}'...")
    for meta_filename in metadata_files:
        meta_file_path = os.path.join(metadata_dir, meta_filename)
        if not os.path.exists(meta_file_path):
            print(f"Warning: Meta file not found: {meta_file_path}")
            continue
        try:
            with open(meta_file_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            if 'source_files' in metadata:
                for src in metadata['source_files']:
                    task_name = src.get('file_path')
                    if not task_name:
                        continue

                    # Source paths are stored relative to the repository root.
                    full_path = os.path.normpath(os.path.join("..", task_name))
                    metadata_task_names.append(task_name)
                    if task_name not in ft_questions:
                        ft_questions[task_name] = []

                    if not os.path.exists(full_path):
                        print(f"Warning: Source file not found: {full_path}")
                        continue

                    with open(full_path, 'r', encoding='utf-8') as f_data:
                        data = json.load(f_data)
                        for item in data:
                            q = None
                            if 'conversations' in item and item['conversations']:
                                q = item['conversations'][0].get('value')
                            elif 'question' in item:
                                q = item.get('question')
                            cleaned_q = clean_question(q, keep_instructions=keep_instructions)
                            if cleaned_q:
                                ft_questions[task_name].append(cleaned_q)

        except Exception as e:
            print(f"Warning: Could not process meta file {meta_filename}: {e}")

    # 2) Explicitly listed additional FT files.
    print(f"Adding {len(additional_files)} additional FT files...")
    for task_name, relative_path in additional_files.items():
        full_path = os.path.normpath(relative_path)
        additional_task_names.append(task_name)
        if task_name not in ft_questions:
            ft_questions[task_name] = []

        if not os.path.exists(full_path):
            print(f"Warning: Additional file not found: {full_path}")
            continue

        try:
            with open(full_path, 'r', encoding='utf-8') as f_data:
                data = json.load(f_data)
                for item in data:
                    q = None
                    if 'conversations' in item and item['conversations']:
                        q = item['conversations'][0].get('value')
                    elif 'question' in item:
                        q = item.get('question')
                    cleaned_q = clean_question(q, keep_instructions=keep_instructions)
                    if cleaned_q:
                        ft_questions[task_name].append(cleaned_q)
        except Exception as e:
            print(f"Warning: Could not process additional file {full_path}: {e}")

    print(f"Loaded questions from {len(ft_questions)} unique FT tasks.")

    # --- Encode all questions once; iterations below only resample them ---
    print("\nCalculating all FT task embeddings (once)...")
    ft_embeddings = {}
    for q_type, questions in ft_questions.items():
        if len(questions) < 2:
            print(f" - Skipping {q_type} (only {len(questions)} question(s))")
            continue
        print(f" - Encoding {q_type} ({len(questions)} q's)")
        try:
            ft_embeddings[q_type] = model.encode(questions, show_progress_bar=True)
        except Exception as e:
            print(f"!!! ERROR encoding {q_type}: {e}")

    if not ft_embeddings:
        print("!!! ERROR: No tasks with sufficient questions to encode.")
        return None, None, None
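
    # model.encode returns a (num_questions, dim) ndarray; for
    # all-mpnet-base-v2 the embedding dimension is 768.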

    print("\nStarting similarity iterations...")
    task_similarity_scores = {q_type: [] for q_type in ft_embeddings}

    for i in range(num_iterations):
        if (i + 1) % 10 == 0 or i == 0:
            print(f" Iteration {i+1}/{num_iterations}...")

        for q_type, embeddings_list in ft_embeddings.items():
            total_count = len(embeddings_list)
            # Cap the sample size so very large tasks stay cheap to average.
            k = min(total_count, MAX_SAMPLES_PER_ITERATION)

            # Randomly split the sampled embeddings into two halves.
            sampled_indices = random.sample(range(total_count), k)
            mid_point = k // 2
            group_a_indices = sampled_indices[:mid_point]
            group_b_indices = sampled_indices[mid_point:]

            group_a_embeds = [embeddings_list[idx] for idx in group_a_indices]
            group_b_embeds = [embeddings_list[idx] for idx in group_b_indices]

            # Compare the mean embedding of each half.
            avg_a = np.mean(group_a_embeds, axis=0).reshape(1, -1)
            avg_b = np.mean(group_b_embeds, axis=0).reshape(1, -1)

            similarity = cosine_similarity(avg_a, avg_b)[0][0]
            task_similarity_scores[q_type].append(float(similarity))

    print("Iterations complete.\n")
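
    # Each score above is a split-half consistency estimate: the cosine
    # similarity between the mean embeddings of two random halves of a task's
    # questions, cos(mean(A), mean(B)). Values near 1.0 indicate a semantically
    # homogeneous task; lower values indicate more intra-task diversity. As a
    # toy sanity check (not executed here), two identical halves give exactly 1.0:
    #
    #   a = np.ones((5, 3)); b = np.ones((5, 3))
    #   cosine_similarity(a.mean(axis=0).reshape(1, -1),
    #                     b.mean(axis=0).reshape(1, -1))[0][0]  # -> 1.0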

    # --- Aggregate per-task and overall statistics ---
    final_stats = {}
    all_scores = []

    for q_type, scores in task_similarity_scores.items():
        if scores:
            mean_sim = np.mean(scores)
            std_sim = np.std(scores)
            final_stats[q_type] = {
                'mean': mean_sim,
                'std': std_sim,
                'iterations': len(scores),
                'num_questions': len(ft_embeddings[q_type])
            }
            all_scores.extend(scores)

    # Overall stats pool every score across all tasks and iterations.
    if all_scores:
        overall_mean = np.mean(all_scores)
        overall_std = np.std(all_scores)
        final_stats['--OVERALL--'] = {
            'mean': overall_mean,
            'std': overall_std,
            'iterations': num_iterations,
            'num_questions': 'N/A'
        }

    return final_stats, metadata_task_names, additional_task_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate intra-task similarity for all Fine-Tuning source datasets.")
    parser.add_argument('--keep_instructions', action='store_true',
                        help="Keep instruction text in questions during embedding.")
    parser.add_argument('--num_iterations', type=int, default=DEFAULT_NUM_ITERATIONS,
                        help=f"Number of random splits to perform (default: {DEFAULT_NUM_ITERATIONS}).")
    args = parser.parse_args()

    # Note: random.sample above is unseeded; seed here (e.g. random.seed(0))
    # if reproducible splits are needed.

    if not METADATA_BASE_DIR:
        print("!!! ERROR: Please fill in the METADATA_BASE_DIR variable in the script.")
        sys.exit(1)

    stats, metadata_tasks, additional_tasks = analyze_ft_intra_similarity(
        metadata_dir=METADATA_BASE_DIR,
        metadata_files=METADATA_FILENAMES,
        additional_files=ADDITIONAL_FT_FILES,
        model_name=MODEL_NAME,
        num_iterations=args.num_iterations,
        keep_instructions=args.keep_instructions
    )

    if stats:
        print("--- Fine-Tuning Intra-Task Similarity Report ---")
        print(f"(Based on {args.num_iterations} random split iterations)\n")

        valid_tasks = [task for task in stats if task != '--OVERALL--']
        if not valid_tasks:
            print("No valid tasks with scores found.")
            sys.exit()

        # Pad the task-name column to the longest name (at least 35 chars).
        max_name_len = max(len(t) for t in valid_tasks)
        max_name_len = max(35, max_name_len)

        header = f"{'Task Name':<{max_name_len}} | {'Questions':>10} | {'Mean Similarity':>18} | {'Std. Deviation':>18}"
        print(header)
        print("-" * len(header))

        # Section 1: explicitly listed additional FT tasks.
        print("\n--- Additional FT Tasks (e.g., RefSpatial) ---")
        additional_tasks_found = 0
        for task in sorted(set(additional_tasks)):
            if task in stats:
                data = stats[task]
                print(f"{task:<{max_name_len}} | {data['num_questions']:>10} | {data['mean']:>18.4f} | {data['std']:>18.4f}")
                additional_tasks_found += 1
        if additional_tasks_found == 0:
            print("No valid results for Additional FT Tasks.")

        # Section 2: tasks discovered via the metadata files.
        print("\n--- Metadata-Discovered Tasks (e.g., VQA, SPAR) ---")
        metadata_tasks_found = 0
        for task in sorted(set(metadata_tasks)):
            if task in stats:
                data = stats[task]
                print(f"{task:<{max_name_len}} | {data['num_questions']:>10} | {data['mean']:>18.4f} | {data['std']:>18.4f}")
                metadata_tasks_found += 1
        if metadata_tasks_found == 0:
            print("No valid results for Metadata-Discovered Tasks.")

        if '--OVERALL--' in stats:
            data = stats['--OVERALL--']
            print("\n" + "-" * len(header))
            print(f"{'--OVERALL--':<{max_name_len}} | {data['num_questions']:>10} | {data['mean']:>18.4f} | {data['std']:>18.4f}")

        # --- Persist a structured JSON report ---
        try:
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            suffix = "_with_instructions" if args.keep_instructions else ""
            report_filename = f"fixed_ft_datasets_intra_task_report{suffix}.json"
            report_save_path = os.path.join(OUTPUT_DIR, report_filename)

            report_data = {
                "report_summary": {
                    "model_name": MODEL_NAME,
                    "num_iterations": args.num_iterations,
                    "keep_instructions": args.keep_instructions
                },
                "overall_stats": stats.get('--OVERALL--', {}),
                "additional_ft_tasks": {},
                "metadata_discovered_tasks": {}
            }

            for task in sorted(set(additional_tasks)):
                if task in stats:
                    report_data["additional_ft_tasks"][task] = stats[task]

            for task in sorted(set(metadata_tasks)):
                if task in stats:
                    report_data["metadata_discovered_tasks"][task] = stats[task]

            with open(report_save_path, 'w', encoding='utf-8') as f:
                json.dump(report_data, f, indent=2)

            print(f"\nStructured summary report saved to {report_save_path}")

        except Exception as e:
            print(f"\n!!! Could not save final summary report to JSON: {e}")

    print("\n--- Analysis finished ---")
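
# Example invocations, assuming this file is saved as ft_intra_similarity.py
# (the filename is illustrative) and the metadata/ and ../*_data paths exist:
#
#   python ft_intra_similarity.py
#   python ft_intra_similarity.py --num_iterations 50 --keep_instructions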