| import json |
| import os |
| import re |
| import pandas as pd |
| from collections import Counter |
| import argparse |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import numpy as np |
| import colorsys |
| import random |
|
|
| |
# Input locations: each dataset's "<prefix>.json" is read from BASE_DATA_DIR
# and its "<prefix>_metadata.json" companion from BASE_METADATA_DIR.
BASE_DATA_DIR: str = '.'
BASE_METADATA_DIR: str = 'metadata'

# Output locations for generated PNG plots; ERQA highlight plots get their own
# sub-folder under OUTPUT_DIR.
OUTPUT_DIR: str = 'analyzed_results'
ERQA_OUTPUT_DIR: str = os.path.join(OUTPUT_DIR, 'ERQA')
|
|
# Display name -> file-name prefix. The analysis functions read
# "<prefix>.json" (answers) and "<prefix>_metadata.json" (task metadata), so
# each prefix must match the on-disk name verbatim, including the
# inconsistent "80.0k" / "80k" suffixes.
DATASET_CATEGORIES = dict(
    SYNTHETIC="synthetic_SAT_Spatial457_PRISM_80.0k",
    REAL="real_SPAR-7M_RoboSpatial_80.0k",
    STATIC="static_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k",
    DYNAMIC="dynamic_Spatial457_SAT_SPAR-7M_80k",
    PERCEPTION="perception_Spatial457_SAT_SPAR-7M_80k",
    REASONING="reasoning_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k",
    _2D="2d_Spatial457_SAT_SPAR-7M_80k",
    _3D="3d_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k",
)
|
|
| |
# ERQA category -> fine-tuning task names (the labels returned by
# map_filename_to_task) whose columns are highlighted in that category's plot.
# A task may be listed under several categories, e.g. 'relative_spatial'
# (Spatial Reasoning, State Estimation) or 'dynamic' (Spatial, Trajectory).
ERQA_TASK_MAPPING_FROM_FILENAME = dict([
    ("Spatial Reasoning", [
        'relative_spatial', 'relative_depth', 'obj_spatial_relation',
        'robospatial_configuration', 'Spatial457_all',
        'static', 'dynamic', '2D', '3D',
    ]),
    ("Action Reasoning", [
        'action_consequence', 'robospatial_compatibility', 'point_prediction',
    ]),
    ("Trajectory Reasoning", [
        'egocentric_movement', 'object_movement', 'goal_aiming', 'dynamic',
    ]),
    ("State Estimation", ['count', 'perception', 'relative_spatial']),
    ("Task Reasoning", ['reasoning', 'robospatial_compatibility']),
    ("Multi-view Reasoning", [
        'allocentric_perspective', 'spatial_imagination', 'obj_spatial_relation',
    ]),
    ("Pointing", ['point_prediction']),
    # No tasks are listed explicitly for "Other".
    ("Other", []),
])
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Base colour per ERQA category (hex values from matplotlib's "tab10"
# palette). The ERQA plotting loop passes these as ``base_color`` so that
# highlighted tasks are drawn in shades of the category colour.
ERQA_COLORS = dict([
    ("Spatial Reasoning", "#1f77b4"),
    ("Action Reasoning", "#ff7f0e"),
    ("Trajectory Reasoning", "#2ca02c"),
    ("State Estimation", "#d62728"),
    ("Task Reasoning", "#9467bd"),
    ("Multi-view Reasoning", "#8c564b"),
    ("Pointing", "#e377c2"),
    ("Other", "#17becf"),
])
|
|
| |
|
|
def categorize_answer(answer_str):
    """Classify a raw answer string into a coarse answer-format label.

    If the string starts with ``{"Reasoning":`` it is decoded as JSON and its
    ``Answer`` field (stringified) is classified instead; when decoding fails,
    the raw string is used. Rules are checked in a fixed priority order and
    the first match wins.

    Args:
        answer_str: The answer text. Any non-``str`` value is accepted.

    Returns:
        A label string such as ``"yes_no"`` or ``"long_descriptive_other"``.
        Non-string input yields ``"not_a_string"``.
    """
    if not isinstance(answer_str, str):
        return "not_a_string"

    text = answer_str
    if answer_str.startswith('{"Reasoning":'):
        try:
            payload = json.loads(answer_str)
        except json.JSONDecodeError:
            pass  # keep the raw string
        else:
            text = str(payload.get('Answer', answer_str)).strip()
    text = text.strip()
    lowered = text.lower()

    if "<point x=" in text:
        return "point_prediction_prism"
    if text.startswith('[(') and text.endswith(')]'):
        return "point_list_robospatial"
    if lowered in ('yes', 'no'):
        return "yes_no"
    if len(text) == 1 and text.upper() in ('A', 'B', 'C', 'D'):
        return "multiple_choice_letter_spar7m"
    # Only short integers (at most 3 digits) count as bare numbers.
    if len(text) < 4 and re.fullmatch(r"^\d+$", text):
        return "number_only"
    if lowered in ('left', 'right', 'above', 'below', 'front', 'behind', 'back'):
        return "simple_direction_sat"
    if 'no objects moved' in text or 'was moved' in text:
        return "object_movement_sat"
    if "From the observer's perspective" in text or "the Object" in text:
        return "descriptive_sentence_spar7m"
    return "short_answer_other" if len(text) < 50 else "long_descriptive_other"
|
|
def map_filename_to_task(file_name, dataset_name):
    """Map a source file name (plus its dataset) to a fine-tuning task label.

    Args:
        file_name: Source file name as recorded in the metadata, e.g.
            ``"relative_depth_tasks.json"``. The extension is ignored.
        dataset_name: Name of the dataset the file belongs to. It is only
            consulted to disambiguate PRISM's generic ``train_data`` file.

    Returns:
        The task label. Known special cases are normalised (``count``,
        ``point_prediction``, ...). A trailing ``_tasks`` suffix is dropped.
        Anything else is returned unchanged, without its extension.
    """
    base_name = os.path.splitext(file_name)[0]

    if 'obj_count' in base_name or base_name == 'count':
        return 'count'
    if base_name == 'robospatial_context':
        return 'point_prediction'
    if base_name == 'train_data' and dataset_name == 'PRISM':
        return 'point_prediction'
    if 'obj_spatial_relation' in base_name:
        return 'obj_spatial_relation'
    if 'spatial_imagination' in base_name:
        return 'spatial_imagination'

    suffix = '_tasks'
    if base_name.endswith(suffix):
        # Strip only the trailing suffix. str.replace would also delete any
        # "_tasks" occurring earlier in the name.
        return base_name[:-len(suffix)]
    # Covers 'Spatial457_all' and every other plain name.
    return base_name
|
|
def generate_color_variations(base_hex, num_variations, min_sat=0.3, max_sat=0.9, min_val=0.45, max_val=0.9, rng=None):
    """Return ``num_variations`` hex colours that share the hue of ``base_hex``.

    Lightness rises linearly from ``min_val`` to ``max_val``, and saturation
    from ``min_sat`` to ``max_sat``, across the variations. A single variation
    sits at the midpoint of both ranges. Gray inputs (``#808080``, ``#7f7f7f``,
    ``"gray"``, ``"grey"``) override the ranges with near-zero saturation, so
    the output stays gray. When more than one colour is produced, the list is
    shuffled.

    Args:
        base_hex: ``#RRGGBB`` or ``#RGB`` colour, or the name ``"gray"`` or
            ``"grey"``. Hex digits after the sixth (e.g. alpha) are ignored.
        num_variations: Number of colours to produce. ``<= 0`` yields ``[]``.
        min_sat, max_sat: Saturation range; values are clamped to [0, 1].
        min_val, max_val: HLS lightness range; values are clamped to [0, 1].
        rng: Optional ``random.Random``-like object used for the shuffle, for
            reproducible output. Defaults to the global ``random`` module.

    Returns:
        List of lowercase ``#rrggbb`` strings.

    Raises:
        ValueError: If ``base_hex`` cannot be parsed as a colour.
    """
    # Resolve colour names before hex parsing. Feeding "gray" to int(..., 16)
    # would raise.
    named_grays = {"gray": "#808080", "grey": "#808080"}
    key = base_hex.strip().lower()
    key = named_grays.get(key, key)
    if key in ("#808080", "#7f7f7f"):
        min_sat, max_sat, min_val, max_val = 0.0, 0.05, 0.35, 0.75

    digits = key.lstrip('#')
    if len(digits) == 3:  # "#RGB" shorthand -> "#RRGGBB"
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Unrecognised colour {base_hex!r}; expected '#RRGGBB' or '#RGB'")
    base_rgb = tuple(int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))
    hue, _, _ = colorsys.rgb_to_hls(*base_rgb)

    span = max(1, num_variations - 1)
    colours = []
    for i in range(num_variations):
        ratio = i / span if num_variations > 1 else 0.5
        lightness = min(1.0, max(0.0, min_val + (max_val - min_val) * ratio))
        saturation = min(1.0, max(0.0, min_sat + (max_sat - min_sat) * ratio))
        r, g, b = colorsys.hls_to_rgb(hue, lightness, saturation)
        colours.append(f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}")

    if num_variations > 1:
        # Shuffle so neighbouring stacked segments don't get adjacent shades.
        (random if rng is None else rng).shuffle(colours)
    return colours
|
|
| |
|
|
def analyze_answer_formats():
    """Compute the answer-format distribution for every dataset category.

    For each entry of ``DATASET_CATEGORIES``, ``<prefix>.json`` is loaded from
    ``BASE_DATA_DIR``. The first ``"gpt"`` turn in each item's
    ``conversations`` list is taken as that item's answer, and the answers are
    bucketed with ``categorize_answer``.

    Returns:
        dict mapping category name to a result dict whose ``"status"`` is one of:
        ``"Success"`` (plus ``"category_counts_percentages"``, a mapping of
        label to percentage rounded to 2 dp), ``"File Not Found"``,
        ``"No 'gpt' answers"``, or ``"Error: <exception text>"``.
    """
    print("--- Starting Answer Format Analysis ---")
    results = {}
    for category, prefix in DATASET_CATEGORIES.items():
        path = os.path.join(BASE_DATA_DIR, f"{prefix}.json")
        if not os.path.exists(path):
            results[category] = {"status": "File Not Found"}
            continue
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                records = json.load(handle)

            answers = []
            for record in records:
                first_gpt = next(
                    (turn for turn in record.get('conversations', []) if turn.get('from') == 'gpt'),
                    None,
                )
                if first_gpt is not None:
                    answers.append(first_gpt.get('value'))

            if not answers:
                results[category] = {"status": "No 'gpt' answers"}
                continue

            n_answers = len(answers)
            labels = Counter(categorize_answer(answer) for answer in answers)
            results[category] = {
                "status": "Success",
                "category_counts_percentages": {
                    label: round(hits / n_answers * 100, 2) for label, hits in labels.items()
                },
            }
        except Exception as exc:  # best-effort: record the failure per dataset
            results[category] = {"status": f"Error: {exc}"}
    return results
|
|
def analyze_question_tasks():
    """Compute the question-task distribution for every dataset category.

    For each entry of ``DATASET_CATEGORIES``, ``<prefix>_metadata.json`` is
    loaded from ``BASE_METADATA_DIR``. Each record in its ``source_files``
    list that has both a ``file_name`` and a ``dataset_name`` contributes its
    ``sampled_entries`` count (default 0) to the task returned by
    ``map_filename_to_task``.

    Returns:
        dict mapping category name to a result dict whose ``"status"`` is one of:
        ``"Success"`` (plus ``"task_counts_percentages"``, a mapping of task to
        percentage rounded to 2 dp), ``"Metadata Not Found"``,
        ``"Invalid Metadata"``, ``"No Sampled Entries"``, or
        ``"Error: <exception text>"``.
    """
    print("--- Starting Question Task Analysis ---")
    results = {}
    for category, prefix in DATASET_CATEGORIES.items():
        meta_path = os.path.join(BASE_METADATA_DIR, f"{prefix}_metadata.json")
        if not os.path.exists(meta_path):
            results[category] = {"status": "Metadata Not Found"}
            continue
        try:
            with open(meta_path, 'r', encoding='utf-8') as handle:
                metadata = json.load(handle)
            if 'source_files' not in metadata:
                results[category] = {"status": "Invalid Metadata"}
                continue

            per_task = Counter()
            grand_total = 0
            for source in metadata['source_files']:
                file_name = source.get('file_name')
                dataset_name = source.get('dataset_name')
                n_sampled = source.get('sampled_entries', 0)
                if not (file_name and dataset_name):
                    continue
                per_task[map_filename_to_task(file_name, dataset_name)] += n_sampled
                grand_total += n_sampled

            if grand_total == 0:
                results[category] = {"status": "No Sampled Entries"}
                continue
            results[category] = {
                "status": "Success",
                "task_counts_percentages": {
                    task: round(count / grand_total * 100, 2) for task, count in per_task.items()
                },
            }
        except Exception as exc:  # best-effort: record the failure per dataset
            results[category] = {"status": f"Error: {exc}"}
    return results
|
|
| |
|
|
def _collect_plot_frame(analysis_results, data_key):
    """Return (dataset x category percentage DataFrame, set of all categories)."""
    plot_data = {}
    all_categories = set()
    for dataset, info in analysis_results.items():
        if info.get('status') == 'Success' and data_key in info:
            plot_data[dataset] = info[data_key]
            all_categories.update(info[data_key])
        else:
            # Keep failed datasets as an all-zero bar so they still appear.
            plot_data[dataset] = {}
    return pd.DataFrame(plot_data).fillna(0).T, all_categories


def _highlight_color_map(columns, highlight_tasks, base_color):
    """Order columns highlighted-first and colour them (base_color shades vs grays)."""
    present = set(columns)
    # dict.fromkeys de-duplicates while keeping the caller's order.
    targets = list(dict.fromkeys(t for t in highlight_tasks if t in present))
    target_set = set(targets)
    others = [c for c in columns if c not in target_set]
    color_map = dict(zip(targets, generate_color_variations(base_color, len(targets))))
    color_map.update(zip(others, generate_color_variations("#808080", len(others))))
    return sorted(targets) + sorted(others), color_map


def _default_color_map(plot_df, top_n=20):
    """Keep the top_n categories by total share, fold the rest into 'Other', and colour them."""
    totals = plot_df.sum().sort_values(ascending=False)
    top = totals.head(top_n).index.tolist()
    rest = totals.iloc[top_n:].index.tolist()
    if rest:
        # A real category named "Other" is merged into the bucket rather than
        # being overwritten or selected twice.
        if 'Other' in top:
            top.remove('Other')
            rest.append('Other')
        # assign() builds a new frame, avoiding SettingWithCopy on a slice.
        plot_df = plot_df[top].assign(Other=plot_df[rest].sum(axis=1))
    else:
        plot_df = plot_df[top]
    plot_df = plot_df[sorted(plot_df.columns)]
    n_colors = len(plot_df.columns)
    palette = sns.color_palette("tab20" if n_colors <= 20 else "husl", n_colors=n_colors)
    color_map = {col: palette[i % len(palette)] for i, col in enumerate(plot_df.columns)}
    return plot_df, color_map


def plot_distribution(analysis_results, analysis_type, highlight_tasks=None, base_color=None, save_path=None):
    """Generate a stacked bar plot, optionally highlighting specific tasks.

    Args:
        analysis_results: Mapping of dataset name to a result dict with a
            ``"status"`` key and, on success, a percentages dict (as produced by
            ``analyze_answer_formats`` / ``analyze_question_tasks``).
        analysis_type: ``"answer_format"`` plots ``category_counts_percentages``.
            Any other value plots ``task_counts_percentages``.
        highlight_tasks: Optional iterable of column names to emphasise. It is
            only honoured when ``base_color`` is also given.
        base_color: Hex colour whose shades are used for highlighted columns.
            All other columns are drawn in grays.
        save_path: PNG path to write. When ``None``, the figure is shown with
            ``plt.show()`` instead of being saved.
    """
    print(f"\n--- Generating Plot for {analysis_type} Distribution ---")

    data_key = "category_counts_percentages" if analysis_type == "answer_format" else "task_counts_percentages"
    df, all_categories = _collect_plot_frame(analysis_results, data_key)
    if df.empty or not all_categories:
        print("!!! ERROR: No data available for plotting.")
        return

    highlight_tasks = list(highlight_tasks or [])
    highlighting = bool(highlight_tasks) and bool(base_color)

    plot_df = df[sorted(all_categories)]
    if highlighting:
        ordered_columns, color_map = _highlight_color_map(list(plot_df.columns), highlight_tasks, base_color)
        plot_df = plot_df[ordered_columns]
    else:
        plot_df, color_map = _default_color_map(plot_df)
    plot_colors = [color_map[col] for col in plot_df.columns]

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(14, 8))
    plot_df.plot(kind='bar', stacked=True, ax=ax, color=plot_colors, width=0.8)

    pretty_type = analysis_type.replace('_', ' ').title()
    ax.set_xlabel("Dataset Category", fontsize=12, fontweight='bold')
    ax.set_ylabel("Percentage Distribution (%)", fontsize=12, fontweight='bold')

    title_suffix = ""
    if highlighting:
        # Label the plot by its output file name, or by the tasks when unsaved.
        if save_path:
            label = os.path.splitext(os.path.basename(save_path))[0]
        else:
            label = ", ".join(map(str, highlight_tasks))
        title_suffix = f" (Highlighting: {label})"
    ax.set_title(f"Distribution of {pretty_type} Across Datasets{title_suffix}", fontsize=16, fontweight='bold', pad=20)
    ax.tick_params(axis='x', rotation=45, labelsize=11)
    ax.tick_params(axis='y', labelsize=11)
    ax.set_ylim(0, 100)
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.xaxis.grid(False)

    ax.legend(title=f"{pretty_type} Categories", bbox_to_anchor=(1.02, 1), loc='upper left',
              borderaxespad=0., fontsize='small', title_fontsize='medium')
    # Reserve the right-hand margin for the outside-the-axes legend.
    plt.tight_layout(rect=[0, 0, 0.83, 1])

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved as {save_path}")
    plt.close(fig)
|
|
| |
def _run_erqa_plots():
    """Write one question-task plot per ERQA category, highlighting its tasks."""
    os.makedirs(ERQA_OUTPUT_DIR, exist_ok=True)
    print("--- Starting ERQA Task Highlighting Analysis ---")
    question_task_results = analyze_question_tasks()

    if not question_task_results or all(v['status'] != 'Success' for v in question_task_results.values()):
        print("!!! ERROR: Could not get valid question task data for ERQA plotting.")
        return

    # Tasks seen in the data that no ERQA category claims. They fill an empty
    # "Other" entry, so "Other" gets a real highlight plot instead of an
    # un-highlighted copy of the plain distribution.
    observed_tasks = set()
    for info in question_task_results.values():
        if info.get('status') == 'Success':
            observed_tasks.update(info.get('task_counts_percentages', {}))
    mapped_tasks = {task for tasks in ERQA_TASK_MAPPING_FROM_FILENAME.values() for task in tasks}

    for erqa_task_name, relevant_ft_tasks in ERQA_TASK_MAPPING_FROM_FILENAME.items():
        if erqa_task_name == "Other" and not relevant_ft_tasks:
            relevant_ft_tasks = sorted(observed_tasks - mapped_tasks)
        if not relevant_ft_tasks:
            print(f"\n--- Skipping ERQA Task: {erqa_task_name} (no relevant tasks) ---")
            continue

        print(f"\n--- Generating plot for ERQA Task: {erqa_task_name} ---")
        base_color = ERQA_COLORS.get(erqa_task_name, "#7f7f7f")
        erqa_save_filename = f"{erqa_task_name.replace(' ', '_')}.png"
        erqa_save_path = os.path.join(ERQA_OUTPUT_DIR, erqa_save_filename)
        plot_distribution(
            analysis_results=question_task_results,
            analysis_type='question_task',
            highlight_tasks=relevant_ft_tasks,
            base_color=base_color,
            save_path=erqa_save_path,
        )


def main():
    """Command-line entry point: run one analysis and write its plot(s).

    Exactly one of ``--analysis_type {answer,question}`` or ``--ERQA`` must be
    given. Plots are written under ``OUTPUT_DIR``. ERQA plots go under
    ``ERQA_OUTPUT_DIR``, one PNG per ERQA category.
    """
    parser = argparse.ArgumentParser(description="Analyze fine-tuning dataset distributions.")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--analysis_type", type=str, choices=['answer', 'question'],
        help="Type of analysis: 'answer' (formats) or 'question' (tasks)."
    )
    group.add_argument(
        "--ERQA", action='store_true',
        help="Generate individual plots highlighting tasks relevant to each ERQA category."
    )
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    if args.analysis_type == 'answer':
        results = analyze_answer_formats()
        save_path = os.path.join(OUTPUT_DIR, "answer_format_distribution.png")
        plot_distribution(results, 'answer_format', save_path=save_path)
    elif args.analysis_type == 'question':
        results = analyze_question_tasks()
        save_path = os.path.join(OUTPUT_DIR, "question_task_distribution.png")
        plot_distribution(results, 'question_task', save_path=save_path)
    elif args.ERQA:
        _run_erqa_plots()

    print("\n--- Analysis finished ---")


if __name__ == "__main__":
    main()