import json import os import re import pandas as pd from collections import Counter import argparse import matplotlib.pyplot as plt import seaborn as sns import numpy as np import colorsys # For creating color variations import random # --- 1. Configuration Settings --- BASE_DATA_DIR = '.' BASE_METADATA_DIR = 'metadata' # UPDATED: Output directory structure OUTPUT_DIR = 'analyzed_results' # Base output directory ERQA_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'ERQA') # Subdirectory for ERQA plots DATASET_CATEGORIES = { "SYNTHETIC": "synthetic_SAT_Spatial457_PRISM_80.0k", "REAL": "real_SPAR-7M_RoboSpatial_80.0k", "STATIC": "static_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k", "DYNAMIC": "dynamic_Spatial457_SAT_SPAR-7M_80k", "PERCEPTION": "perception_Spatial457_SAT_SPAR-7M_80k", "REASONING": "reasoning_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k", "_2D": "2d_Spatial457_SAT_SPAR-7M_80k", "_3D": "3d_RoboSpatial_PRISM_Spatial457_SAT_SPAR-7M_80.0k", } # --- 2. ERQA Task Mapping & Colors --- ERQA_TASK_MAPPING_FROM_FILENAME = { "Spatial Reasoning": [ 'relative_spatial', # SAT: Asks about left/right/above/below relationships. 'relative_depth', # SAT: Asks about closer/further, higher/lower relationships. 'obj_spatial_relation', # SPAR-7M: Asks about relative object locations (e.g., left, below, farther). 'robospatial_configuration',# RoboSpatial: Asks simple positional questions (e.g., A in front of B?). 'Spatial457_all', # Spatial457: Likely contains general spatial relationship questions. 'static', # Spatial457: Focuses on static scenes, implies spatial relationships. 'dynamic', '2D', # Spatial457: 2D spatial relationships. '3D', # Spatial457: 3D spatial relationships. ], "Action Reasoning": [ 'action_consequence', # SAT: Asks about the outcome of an action (e.g., facing away after turning?). 'robospatial_compatibility',# RoboSpatial: Asks if an action is possible (e.g., can X fit Y?). 'point_prediction', # PRISM: Asks for the grasp point *to accomplish a specific task*. 
], "Trajectory Reasoning": [ 'egocentric_movement', # SAT: Asks how the camera moved/rotated. 'object_movement', # SAT: Asks if/how objects moved. 'goal_aiming', # SAT: Asks which way to turn to face a target. 'dynamic', # Spatial457: Likely involves predicting outcomes of movement (e.g., collision). ], "State Estimation": [ 'count', # SAT, SPAR-7M: Asks for the number of objects. 'perception', # Spatial457: Asks about object properties (color, shape, size). 'relative_spatial' # Note: Spatial457's 2D/3D/static/reasoning also contain property questions, reinforcing this link. ], "Task Reasoning": [ 'reasoning', # Spatial457: Contains broader reasoning Qs, some task-related (e.g., size comparison for collision). Fits ERQA's broad definition. 'robospatial_compatibility' # Note: action_consequence could also weakly fit here. ], "Multi-view Reasoning": [ 'allocentric_perspective', # SAT: Asks about spatial relationships from a different imagined viewpoint. 'spatial_imagination', # SPAR-7M: Asks how relationships change after observer movement. 'obj_spatial_relation' # Note: obj_spatial_relation might include multi-view variants based on original SPAR-7M file names, # but the grouped name 'obj_spatial_relation' doesn't specify view count. ], "Pointing": [ 'point_prediction', # RoboSpatial: Asks to pinpoint multiple points in a specified vacant area. ], "Other": [ # Fine-tuning tasks that don't clearly fit above. # Currently, all major tasks seem reasonably mapped. 
] } # ERQA_TASK_MAPPING_FROM_FILENAME = { # "Spatial Reasoning": ['relative_spatial', 'relative_depth', 'obj_spatial_relation', 'robospatial_configuration', 'Spatial457_all', 'static', '2D', '3D'], # "Action Reasoning": ['action_consequence', 'robospatial_compatibility', 'point_prediction'], # Includes PRISM # "Trajectory Reasoning": ['egocentric_movement', 'object_movement', 'goal_aiming', 'dynamic'], # "State Estimation": ['count', 'perception'], # "Task Reasoning": ['reasoning'], # "Multi-view Reasoning": ['allocentric_perspective', 'spatial_imagination'], # "Pointing": ['point_prediction'], # Includes RoboSpatial context # "Other": [] # } # Define NEW representative colors for ERQA tasks ERQA_COLORS = { "Spatial Reasoning": "#1f77b4", # 선명한 파랑 (Blue) "Action Reasoning": "#ff7f0e", # 선명한 주황 (Orange) "Trajectory Reasoning": "#2ca02c", # 선명한 초록 (Green) "State Estimation": "#d62728", # 선명한 빨강 (Red) "Task Reasoning": "#9467bd", # 선명한 보라 (Purple) "Multi-view Reasoning": "#8c564b", # 갈색 (Brown) "Pointing": "#e377c2", # 분홍 (Pink) "Other": "#17becf" # 청록 (Teal/Cyan) } # --- 3. Helper Functions --- def categorize_answer(answer_str): # (Existing function - unchanged) if not isinstance(answer_str, str): return "not_a_string" parsed_answer = answer_str if answer_str.startswith('{"Reasoning":'): try: data = json.loads(answer_str); parsed_answer = str(data.get('Answer', answer_str)).strip() except json.JSONDecodeError: parsed_answer = answer_str parsed_answer = parsed_answer.strip() if " 1 else 0.5 s_ratio = i / max(1, num_variations - 1) if num_variations > 1 else 0.5 l = min_val + (max_val - min_val) * l_ratio s = min_sat + (max_sat - min_sat) * s_ratio l = max(0.0, min(1.0, l)); s = max(0.0, min(1.0, s)) r, g, b = colorsys.hls_to_rgb(base_h, l, s) variations.append(f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}") if num_variations > 1 : random.shuffle(variations) return variations # --- 4. 
# Analysis Functions ---


def _to_percentages(counts, total):
    """Convert a mapping of raw counts into percentages of *total*, rounded to 2 decimals."""
    return {key: round(count / total * 100, 2) for key, count in counts.items()}


def analyze_answer_formats():
    """Classify the first 'gpt' reply of every conversation in each dataset.

    For every entry in DATASET_CATEGORIES this loads
    ``<BASE_DATA_DIR>/<prefix>.json``. From each item's ``conversations`` list
    it takes the ``value`` of the first turn whose ``from`` field is ``'gpt'``.
    Each such answer is labelled with ``categorize_answer``, and the share of
    each label is reported.

    Returns:
        dict mapping dataset category name -> result dict. Every result has a
        ``"status"`` key. Successful results also carry
        ``"category_counts_percentages"`` (label -> percentage of answers).
    """
    print("--- Starting Answer Format Analysis ---")
    all_results = {}
    for name, file_prefix in DATASET_CATEGORIES.items():
        file_path = os.path.join(BASE_DATA_DIR, f"{file_prefix}.json")
        if not os.path.exists(file_path):
            all_results[name] = {"status": "File Not Found"}
            continue
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            gpt_answers = []
            for item in data:
                for conv in item.get('conversations', []):
                    if conv.get('from') == 'gpt':
                        gpt_answers.append(conv.get('value'))
                        break  # only the first 'gpt' turn of each item is counted
            if not gpt_answers:
                all_results[name] = {"status": "No 'gpt' answers"}
                continue
            counts = Counter(categorize_answer(ans) for ans in gpt_answers)
            all_results[name] = {
                "status": "Success",
                "category_counts_percentages": _to_percentages(counts, len(gpt_answers)),
            }
        except Exception as e:  # best-effort per dataset: one bad file must not abort the rest
            all_results[name] = {"status": f"Error: {e}"}
    return all_results


def analyze_question_tasks():
    """Compute the task mix of each dataset from its sampling metadata.

    For every entry in DATASET_CATEGORIES this loads
    ``<BASE_METADATA_DIR>/<prefix>_metadata.json``. It then walks its
    ``source_files`` list, labels each source file via
    ``map_filename_to_task(file_name, dataset_name)``, and weights the label by
    the file's ``sampled_entries``.

    Returns:
        dict mapping dataset category name -> result dict. Every result has a
        ``"status"`` key. Successful results also carry
        ``"task_counts_percentages"`` (task -> percentage of sampled entries).
    """
    print("--- Starting Question Task Analysis ---")
    all_results = {}
    for name, file_prefix in DATASET_CATEGORIES.items():
        metadata_path = os.path.join(BASE_METADATA_DIR, f"{file_prefix}_metadata.json")
        if not os.path.exists(metadata_path):
            all_results[name] = {"status": "Metadata Not Found"}
            continue
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            if 'source_files' not in metadata:
                all_results[name] = {"status": "Invalid Metadata"}
                continue
            task_counts = Counter()
            total_entries = 0
            for src in metadata['source_files']:
                fname, dname = src.get('file_name'), src.get('dataset_name')
                # A missing or null count contributes nothing instead of raising TypeError.
                entries = src.get('sampled_entries') or 0
                if fname and dname:
                    task = map_filename_to_task(fname, dname)
                    task_counts[task] += entries
                    total_entries += entries
            if total_entries == 0:
                all_results[name] = {"status": "No Sampled Entries"}
                continue
            all_results[name] = {
                "status": "Success",
                "task_counts_percentages": _to_percentages(task_counts, total_entries),
            }
        except Exception as e:  # best-effort per dataset: one bad file must not abort the rest
            all_results[name] = {"status": f"Error: {e}"}
    return all_results


# --- 5. Plotting Function ---
def _highlight_columns(plot_df, highlight_tasks, base_color):
    """Put highlighted columns first and color them in base_color shades, the rest in gray shades."""
    # dict.fromkeys de-duplicates while keeping order; duplicates would create repeated columns.
    target_tasks = list(dict.fromkeys(t for t in highlight_tasks if t in plot_df.columns))
    target_set = set(target_tasks)
    other_tasks = [t for t in plot_df.columns if t not in target_set]
    target_colors = generate_color_variations(base_color, len(target_tasks))
    gray_colors = generate_color_variations("#808080", len(other_tasks))
    color_map = {task: target_colors[i % len(target_colors)] for i, task in enumerate(target_tasks)}
    color_map.update({task: gray_colors[i % len(gray_colors)] for i, task in enumerate(other_tasks)})
    # Highlighted tasks first (sorted), then the others (sorted).
    return plot_df[sorted(target_tasks) + sorted(other_tasks)], color_map


def _top_n_columns(plot_df, top_n=20):
    """Keep the top_n columns by summed value, fold the rest into 'Other', and assign a palette."""
    category_totals = plot_df.sum().sort_values(ascending=False)
    top_categories = category_totals.head(top_n).index.tolist()
    tail_categories = category_totals.iloc[top_n:].index.tolist()
    kept = plot_df[top_categories].copy()  # explicit copy: we add a column below
    if tail_categories:
        tail_sum = plot_df[tail_categories].sum(axis=1)
        # Merge into an already-present 'Other' column rather than overwriting it,
        # which would drop its values and yield a duplicate 'Other' column.
        kept['Other'] = kept['Other'] + tail_sum if 'Other' in kept.columns else tail_sum
    kept = kept[sorted(kept.columns)]
    num_colors = len(kept.columns)
    palette_name = "tab20" if num_colors <= 20 else "husl"
    palette = sns.color_palette(palette_name, n_colors=num_colors)
    color_map = {col: palette[i % len(palette)] for i, col in enumerate(kept.columns)}
    return kept, color_map


def plot_distribution(analysis_results, analysis_type, highlight_tasks=None, base_color=None, save_path=None):
    """Generate and save a stacked percentage bar plot with one bar per dataset.

    Args:
        analysis_results: output of analyze_answer_formats() or
            analyze_question_tasks(). Entries whose status is not 'Success'
            are drawn as empty bars.
        analysis_type: 'answer_format' reads "category_counts_percentages";
            any other value reads "task_counts_percentages". Also used in the
            title, the legend title and the default file name.
        highlight_tasks: optional iterable of category names to draw in shades
            of *base_color*; every other category is drawn in shades of gray.
            Highlighting only happens when *base_color* is also given.
        base_color: hex color string used as the base for highlighted tasks.
        save_path: output PNG path. Defaults to
            ``<OUTPUT_DIR>/<analysis_type>_distribution.png``.

    Without highlighting, only the 20 largest categories (by summed
    percentage) are kept and the remainder is merged into 'Other'.
    """
    print(f"\n--- Generating Plot for {analysis_type} Distribution ---")
    if save_path is None:
        save_path = os.path.join(OUTPUT_DIR, f"{analysis_type}_distribution.png")
    data_key = "category_counts_percentages" if analysis_type == "answer_format" else "task_counts_percentages"

    plot_data = {}
    all_categories = set()
    for dataset, info in analysis_results.items():
        if info.get('status') == 'Success' and data_key in info:
            plot_data[dataset] = info[data_key]
            all_categories.update(info[data_key].keys())
        else:
            plot_data[dataset] = {}  # keep an empty bar for failed datasets

    df = pd.DataFrame(plot_data).fillna(0).T
    if df.empty or not all_categories:
        print("!!! ERROR: No data available for plotting.")
        return
    plot_df = df[sorted(all_categories)]

    # --- Generate Colors (Highlighting or Default) ---
    highlighting = bool(highlight_tasks and base_color)
    if highlighting:
        plot_df, color_map = _highlight_columns(plot_df, highlight_tasks, base_color)
    else:
        plot_df, color_map = _top_n_columns(plot_df, top_n=20)
    plot_colors = [color_map[col] for col in plot_df.columns]

    # --- Plotting ---
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(14, 8))
    try:
        plot_df.plot(kind='bar', stacked=True, ax=ax, color=plot_colors, width=0.8)

        # --- Formatting ---
        ax.set_xlabel("Dataset Category", fontsize=12, fontweight='bold')
        ax.set_ylabel("Percentage Distribution (%)", fontsize=12, fontweight='bold')
        # The file name (without extension) identifies what is being highlighted.
        title_suffix = f" (Highlighting: {os.path.basename(save_path).replace('.png','')})" if highlighting else ""
        title = f"Distribution of {analysis_type.replace('_', ' ').title()} Across Datasets{title_suffix}"
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.tick_params(axis='x', rotation=45, labelsize=11)
        ax.tick_params(axis='y', labelsize=11)
        ax.set_ylim(0, 100)
        ax.yaxis.grid(True, linestyle='--', alpha=0.7)
        ax.xaxis.grid(False)

        # --- Legend ---
        legend_title = f"{analysis_type.replace('_', ' ').title()} Categories"
        ax.legend(title=legend_title, bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.,
                  fontsize='small', title_fontsize='medium')
        # Reserve the right-hand strip for the outside legend.
        fig.tight_layout(rect=[0, 0, 0.83, 1])

        # --- Save the plot ---
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved as {save_path}")
    finally:
        plt.close(fig)  # always release the figure, even if plotting/saving failed


# --- 6. Main Execution Block ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze fine-tuning dataset distributions.")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--analysis_type", type=str, choices=['answer', 'question'],
        help="Type of analysis: 'answer' (formats) or 'question' (tasks)."
    )
    group.add_argument(
        "--ERQA", action='store_true',
        help="Generate individual plots highlighting tasks relevant to each ERQA category."
    )
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    if args.analysis_type == 'answer':
        results = analyze_answer_formats()
        save_path = os.path.join(OUTPUT_DIR, "answer_format_distribution.png")
        plot_distribution(results, 'answer_format', save_path=save_path)
    elif args.analysis_type == 'question':
        results = analyze_question_tasks()
        save_path = os.path.join(OUTPUT_DIR, "question_task_distribution.png")
        plot_distribution(results, 'question_task', save_path=save_path)
    elif args.ERQA:
        os.makedirs(ERQA_OUTPUT_DIR, exist_ok=True)
        print("--- Starting ERQA Task Highlighting Analysis ---")
        question_task_results = analyze_question_tasks()
        if not question_task_results or all(v['status'] != 'Success' for v in question_task_results.values()):
            print("!!! ERROR: Could not get valid question task data for ERQA plotting.")
        else:
            for erqa_task_name, relevant_ft_tasks in ERQA_TASK_MAPPING_FROM_FILENAME.items():
                # An ERQA category with no mapped fine-tuning tasks (e.g. the empty
                # "Other") would only produce an un-highlighted plot; skip it.
                if not relevant_ft_tasks:
                    continue
                print(f"\n--- Generating plot for ERQA Task: {erqa_task_name} ---")
                base_color = ERQA_COLORS.get(erqa_task_name, "#7f7f7f")  # default to gray
                erqa_save_filename = f"{erqa_task_name.replace(' ', '_')}.png"
                erqa_save_path = os.path.join(ERQA_OUTPUT_DIR, erqa_save_filename)
                plot_distribution(
                    analysis_results=question_task_results,
                    analysis_type='question_task',
                    highlight_tasks=relevant_ft_tasks,
                    base_color=base_color,
                    save_path=erqa_save_path
                )

    print("\n--- Analysis finished ---")