# calculate_embedding_within_each_task.py
import json
import os
import re
import numpy as np
import argparse
from sentence_transformers import SentenceTransformer # pip install sentence-transformers
from sklearn.metrics.pairwise import cosine_similarity # pip install scikit-learn
import random
import sys
# --- 1. Configuration Settings (from original script) ---
# TODO: Fill in the directory containing the 8 metadata JSON files
METADATA_BASE_DIR = 'metadata'
# Directory where the final similarity report is written
OUTPUT_DIR = "intra_task"
# --- Hardcoded list of metadata filenames ---
METADATA_FILENAMES = [
# '2d_Spatial457_SAT_SPAR-7M_80k_metadata.json',
# '3d_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k_metadata.json',
# 'dynamic_Spatial457_SAT_SPAR-7M_80k_metadata.json',
# 'perception_Spatial457_SAT_SPAR-7M_80k_metadata.json',
# 'real_SPAR-7M_RoboSpatial_80.0k_metadata.json',
# 'reasoning_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k_metadata.json',
# 'static_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k_metadata.json',
# 'synthetic_SAT_Spatial457_PRISM_80.0k_metadata.json'
]
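# NOTE: every filename above is commented out by default, so the metadata pass
# is skipped entirely; uncomment entries (or add your own) to include them.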
# --- FT files (hardcoded list of FT data) ---
ADDITIONAL_FT_FILES = {
# RefSpatial
"RefSpatial_2D_choice_qa": "../RefSpatial_data/2D/choice_qa.json",
"RefSpatial_2D_reasoning_template_qa": "../RefSpatial_data/2D/reasoning_template_qa.json",
"RefSpatial_3D_choice_qa": "../RefSpatial_data/3D/choice_qa.json",
"RefSpatial_3D_multi_view_qa": "../RefSpatial_data/3D/multi_view_qa.json",
"RefSpatial_3D_reasoning_template_qa": "../RefSpatial_data/3D/reasoning_template_qa.json",
"RefSpatial_3D_vacant_qa": "../RefSpatial_data/3D/vacant_qa.json",
"RefSpatial_3D_visual_choice_qa": "../RefSpatial_data/3D/visual_choice_qa.json",
# Spatial457
"Spatial457_L1_single": "../Spatial457_data/qwen_data_new/L1_single.json",
"Spatial457_L2_objects": "../Spatial457_data/qwen_data_new/L2_objects.json",
"Spatial457_L3_2d_spatial": "../Spatial457_data/qwen_data_new/L3_2d_spatial.json",
"Spatial457_L4_occ": "../Spatial457_data/qwen_data_new/L4_occ.json",
"Spatial457_L4_pose": "../Spatial457_data/qwen_data_new/L4_pose.json",
"Spatial457_L5_6d_spatial": "../Spatial457_data/qwen_data_new/L5_6d_spatial.json",
"Spatial457_L5_collision": "../Spatial457_data/qwen_data_new/L5_collision.json",
# SPAR-7M
"SPAR-7M_obj_count": "../SPAR-7M_data/qwen_data/obj_count.json",
"SPAR-7M_obj_spatial_relation": "../SPAR-7M_data/qwen_data/obj_spatial_relation.json",
"SPAR-7M_spatial_imagination": "../SPAR-7M_data/qwen_data/spatial_imagination.json",
# SAT
"SAT_action_consequence": "../SAT_data/qwen_data_new/action_consequence.json",
"SAT_action_sequence": "../SAT_data/qwen_data_new/action_sequence.json",
"SAT_goal_aim": "../SAT_data/qwen_data_new/goal_aim.json",
"SAT_obj_movement": "../SAT_data/qwen_data_new/obj_movement.json",
"SAT_other": "../SAT_data/qwen_data_new/other.json",
"SAT_perspective": "../SAT_data/qwen_data_new/perspective.json",
# PRISM
"PRISM_train_data": "../PRISM_data/qwen_data/train_data.json",
# RoboSpatial
"RoboSpatial_compatibility": "../RoboSpatial_data/qwen_data/robospatial_compatibility.json",
"RoboSpatial_configuration": "../RoboSpatial_data/qwen_data/robospatial_configuration.json",
"RoboSpatial_context": "../RoboSpatial_data/qwen_data/robospatial_context.json",
}
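# Each file above is expected to be a JSON list whose items expose the question
# either as item['conversations'][0]['value'] (conversation-style data) or as
# item['question'] (plain QA); see the extraction logic in the loaders below.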
# --- Analysis Configuration ---
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
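# all-mpnet-base-v2 is a general-purpose SBERT checkpoint; any other
# sentence-transformers model should drop in here unchanged.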
DEFAULT_NUM_ITERATIONS = 20
MAX_SAMPLES_PER_ITERATION = 10000
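# Methodology: each iteration draws up to MAX_SAMPLES_PER_ITERATION embeddings
# per task, splits them into two random halves, and takes the cosine similarity
# between the two half-mean vectors (a "split-half" estimate). Means near 1.0
# indicate an internally homogeneous task; lower means indicate more diverse
# questions. The std across iterations reflects sampling noise.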
# --- 2. Helper Function (from original script) ---
def clean_question(text, keep_instructions=False):
"""Cleans question text, optionally removing instructions."""
if not isinstance(text, str): return ""
cleaned_text = re.sub(r'^<image>\n?', '', text).strip()
# If we keep instructions, just do the basic clean and return
if keep_instructions:
return cleaned_text
# If we remove instructions, apply all regex rules
cleaned_text = re.sub(r'Choices:.*?(?=\nPlease answer|\n|$)', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'\n\nChoices:\s*.*?(?=\n|$)', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'\nPlease answer directly.*', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'The options describe the spatial relationship.*?(\nChoose|\nPlease select|\nPick|\nSelect)', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'(\nChoose|\nPlease select|\nPick|\nSelect)\s+the (correct|right|appropriate) (response|option|answer).*?Your answer can only include.*?$', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'\.?\s*Final answer should be.*?$', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'\.?\s*Your final answer should be formatted.*?points\.', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = re.sub(r'\.?\s*Please use the world coordinate system.*?objects\.', '', cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()
cleaned_text = cleaned_text.rstrip('?.')
return cleaned_text.strip()
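# Illustrative example (hypothetical input, not taken from any dataset):
#   clean_question("<image>\nWhere is the cup?\nChoices: A. left B. right"
#                  "\nPlease answer directly with the option letter.")
#   -> "Where is the cup"
# With keep_instructions=True, only the leading <image> tag is stripped.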
# --- 3. Main Analysis Function ---
def analyze_ft_intra_similarity(metadata_dir, metadata_files, additional_files, model_name, num_iterations, keep_instructions):
"""
Calculates intra-task similarity for all specified fine-tuning datasets.
"""
print(f"Using Sentence Transformer model: {model_name}")
print(f"Keep instructions: {keep_instructions}")
print(f"Running {num_iterations} iterations...\n")
try:
model = SentenceTransformer(model_name)
except Exception as e:
print(f"!!! ERROR loading model '{model_name}': {e}")
return None, None, None
# --- Load FT Data ---
print("Processing Fine-tuning source data...")
ft_questions = {}
metadata_task_names = []
additional_task_names = []
# 1. Load from metadata
print(f"Reading {len(metadata_files)} metadata files from '{metadata_dir}'...")
for meta_filename in metadata_files:
meta_file_path = os.path.join(metadata_dir, meta_filename)
if not os.path.exists(meta_file_path):
print(f"Warning: Meta file not found: {meta_file_path}")
continue
try:
with open(meta_file_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
if 'source_files' in metadata:
for src in metadata['source_files']:
task_name = src.get('file_path') # Use relative path as task name
if not task_name: continue
full_path = os.path.normpath(os.path.join("..", task_name))
metadata_task_names.append(task_name)
if task_name not in ft_questions:
ft_questions[task_name] = []
if not os.path.exists(full_path):
print(f"Warning: Source file not found: {full_path}")
continue
# Load questions from this source file
with open(full_path, 'r', encoding='utf-8') as f_data:
data = json.load(f_data)
for item in data:
q = None
if 'conversations' in item and item['conversations']: q = item['conversations'][0].get('value')
elif 'question' in item: q = item.get('question')
cleaned_q = clean_question(q, keep_instructions=keep_instructions)
if cleaned_q: ft_questions[task_name].append(cleaned_q)
except Exception as e:
print(f"Warning: Could not process meta file {meta_filename}: {e}")
# 2. Add additional files
print(f"Adding {len(additional_files)} additional FT files...")
for task_name, relative_path in additional_files.items():
full_path = os.path.normpath(relative_path)
additional_task_names.append(task_name)
if task_name not in ft_questions:
ft_questions[task_name] = []
if not os.path.exists(full_path):
print(f"Warning: Additional file not found: {full_path}")
continue
try:
# Load questions from this source file
with open(full_path, 'r', encoding='utf-8') as f_data:
data = json.load(f_data)
for item in data:
q = None
if 'conversations' in item and item['conversations']: q = item['conversations'][0].get('value')
elif 'question' in item: q = item.get('question')
cleaned_q = clean_question(q, keep_instructions=keep_instructions)
if cleaned_q: ft_questions[task_name].append(cleaned_q)
except Exception as e:
print(f"Warning: Could not process additional file {full_path}: {e}")
print(f"Loaded questions from {len(ft_questions)} unique FT tasks.")
# --- Pre-calculate all embeddings ---
print("\nCalculating all FT task embeddings (once)...")
ft_embeddings = {}
for q_type, questions in ft_questions.items():
if len(questions) < 2:
print(f" - Skipping {q_type} (only {len(questions)} question(s))")
continue
print(f" - Encoding {q_type} ({len(questions)} q's)")
try:
ft_embeddings[q_type] = model.encode(questions, show_progress_bar=True)
except Exception as e:
print(f"!!! ERROR encoding {q_type}: {e}")
if not ft_embeddings:
print("!!! ERROR: No tasks with sufficient questions to encode.")
return None, None, None
print("\nStarting similarity iterations...")
task_similarity_scores = {q_type: [] for q_type in ft_embeddings}
for i in range(num_iterations):
if (i + 1) % 10 == 0 or i == 0:
print(f" Iteration {i+1}/{num_iterations}...")
for q_type, embeddings_list in ft_embeddings.items():
total_count = len(embeddings_list)
k = min(total_count, MAX_SAMPLES_PER_ITERATION)
sampled_indices = random.sample(range(total_count), k)
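            # Split-half: shuffle once via random.sample, then divide the
            # sampled indices into two disjoint groups of (almost) equal size.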
mid_point = k // 2
group_a_indices = sampled_indices[:mid_point]
group_b_indices = sampled_indices[mid_point:]
group_a_embeds = [embeddings_list[idx] for idx in group_a_indices]
group_b_embeds = [embeddings_list[idx] for idx in group_b_indices]
avg_a = np.mean(group_a_embeds, axis=0).reshape(1, -1)
avg_b = np.mean(group_b_embeds, axis=0).reshape(1, -1)
similarity = cosine_similarity(avg_a, avg_b)[0][0]
# --- FIX: Cast numpy.float32 to a standard python float ---
task_similarity_scores[q_type].append(float(similarity))
print("Iterations complete.\n")
# --- Calculate Final Statistics ---
final_stats = {}
all_scores = []
for q_type, scores in task_similarity_scores.items():
if scores:
            mean_sim = np.mean(scores)
            std_sim = np.std(scores)
            final_stats[q_type] = {
                # Cast numpy scalars to plain floats so json.dump works later
                'mean': float(mean_sim),
                'std': float(std_sim),
                'iterations': len(scores),
                'num_questions': len(ft_embeddings[q_type])
            }
all_scores.extend(scores)
# Calculate overall average
if all_scores:
        overall_mean = np.mean(all_scores)
        overall_std = np.std(all_scores)
        final_stats['--OVERALL--'] = {
            'mean': float(overall_mean),
            'std': float(overall_std),
            'iterations': num_iterations,
            'num_questions': 'N/A'
        }
# Return stats and the lists of task names for grouped printing
return final_stats, metadata_task_names, additional_task_names
# --- 4. Main Execution Block ---
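# Example invocation (assumes the config paths above exist):
#   python calculate_embedding_within_each_task.py --num_iterations 50 --keep_instructions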
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Calculate intra-task similarity for all Fine-Tuning source datasets.")
parser.add_argument('--keep_instructions', action='store_true',
help="Keep instruction text in questions during embedding.")
parser.add_argument('--num_iterations', type=int, default=DEFAULT_NUM_ITERATIONS,
help=f"Number of random splits to perform (default: {DEFAULT_NUM_ITERATIONS}).")
args = parser.parse_args()
if not METADATA_BASE_DIR:
print("!!! ERROR: Please fill in the METADATA_BASE_DIR variable in the script.")
        sys.exit(1)
stats, metadata_tasks, additional_tasks = analyze_ft_intra_similarity(
metadata_dir=METADATA_BASE_DIR,
metadata_files=METADATA_FILENAMES,
additional_files=ADDITIONAL_FT_FILES,
model_name=MODEL_NAME,
num_iterations=args.num_iterations,
keep_instructions=args.keep_instructions
)
if stats:
print("--- Fine-Tuning Intra-Task Similarity Report ---")
print(f"(Based on {args.num_iterations} random split iterations)\n")
valid_tasks = [task for task in stats if task != '--OVERALL--']
if not valid_tasks:
print("No valid tasks with scores found.")
sys.exit()
max_name_len = max(len(t) for t in valid_tasks)
max_name_len = max(35, max_name_len)
header = f"{'Task Name':<{max_name_len}} | {'Questions':>10} | {'Mean Similarity':>18} | {'Std. Deviation':>18}"
print(header)
print("-" * len(header))
# --- (Console Print) Group 1: Additional FT Tasks ---
print(f"\n--- Additional FT Tasks (e.g., RefSpatial) ---")
additional_tasks_found = 0
for task in sorted(list(set(additional_tasks))):
if task in stats:
data = stats[task]
print(f"{task:<{max_name_len}} | {data['num_questions']:>10} | {data['mean']:>18.4f} | {data['std']:>18.4f}")
additional_tasks_found += 1
if additional_tasks_found == 0:
print("No valid results for Additional FT Tasks.")
# --- (Console Print) Group 2: Metadata-Discovered Tasks ---
print(f"\n--- Metadata-Discovered Tasks (e.g., VQA, SPAR) ---")
metadata_tasks_found = 0
for task in sorted(list(set(metadata_tasks))):
if task in stats:
data = stats[task]
print(f"{task:<{max_name_len}} | {data['num_questions']:>10} | {data['mean']:>18.4f} | {data['std']:>18.4f}")
metadata_tasks_found += 1
if metadata_tasks_found == 0:
print("No valid results for Metadata-Discovered Tasks.")
# --- (Console Print) Overall Summary ---
if '--OVERALL--' in stats:
data = stats['--OVERALL--']
print("\n" + "-" * len(header))
print(f"{'--OVERALL--':<{max_name_len}} | {data['num_questions']:>10} | {data['mean']:>18.4f} | {data['std']:>18.4f}")
try:
os.makedirs(OUTPUT_DIR, exist_ok=True)
suffix = "_with_instructions" if args.keep_instructions else ""
report_filename = f"fixed_ft_datasets_intra_task_report{suffix}.json"
report_save_path = os.path.join(OUTPUT_DIR, report_filename)
report_data = {
"report_summary": {
"model_name": MODEL_NAME,
"num_iterations": args.num_iterations,
"keep_instructions": args.keep_instructions
},
"overall_stats": stats.get('--OVERALL--', {}),
"additional_ft_tasks": {},
"metadata_discovered_tasks": {}
}
# 1. Additional FT Tasks (RefSpatial, etc.)
for task in sorted(list(set(additional_tasks))):
if task in stats:
report_data["additional_ft_tasks"][task] = stats[task]
# 2. Metadata-Discovered Tasks (VQA, SPAR, etc.)
for task in sorted(list(set(metadata_tasks))):
if task in stats:
report_data["metadata_discovered_tasks"][task] = stats[task]
with open(report_save_path, 'w', encoding='utf-8') as f:
json.dump(report_data, f, indent=2)
print(f"\nStructured summary report saved to {report_save_path}")
except Exception as e:
print(f"\n!!! Could not save final summary report to JSON: {e}")
print("\n--- Analysis finished ---")