import json
import os  # Import os for path joining

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Provided data (shortened for brevity, full data used in execution)
# data = {
#     "SYNTHETIC": {
#         "status": "Success",
#         "category_counts": {
#             "point_prediction_prism": 32.85, "simple_direction_sat": 11.1, "yes_no": 8.51, "short_answer_other": 8.31,
#             "number_only": 5.37, "json_answer:small": 3.15, "json_answer:large": 3.14, "json_answer:no": 2.68,
#             "json_answer:yes": 2.6, "multiple_choice_letter_spar7m": 1.85, "object_movement_sat": 1.51,
#             "json_answer:1": 1.49, "json_answer:0": 1.2, "json_answer:2": 0.95, # ... other json_answers omitted
#             "long_descriptive_other": 0.82
#             # ... other categories omitted for brevity in this display
#         }
#     },
#     "REAL": {
#         "status": "Success",
#         "category_counts": {
#             "yes_no": 33.33, "multiple_choice_letter_spar7m": 22.7, "point_list_robospatial": 16.67,
#             "number_only": 12.29, "long_descriptive_other": 10.35, "descriptive_sentence_spar7m": 4.66
#         }
#     },
#     "STATIC": {
#         "status": "Success",
#         "category_counts": {
#             "point_prediction_prism": 20.0, "yes_no": 17.12, "multiple_choice_letter_spar7m": 13.26,
#             "simple_direction_sat": 7.77, "point_list_robospatial": 6.67, "long_descriptive_other": 6.14,
#             "number_only": 3.87, "short_answer_other": 2.68, "descriptive_sentence_spar7m": 2.51,
#             "json_answer:large": 2.15, "json_answer:small": 2.15, "json_answer:no": 1.81, "json_answer:yes": 1.78,
#             "json_answer:1": 1.01 # ... other categories omitted
#         }
#     },
#     "DYNAMIC": {
#         "status": "Success",
#         "category_counts": {
#             "number_only": 48.02, "short_answer_other": 23.16, "yes_no": 16.92, "object_movement_sat": 7.87,
#             "json_answer:large": 0.62, "json_answer:small": 0.57 # ...
#             # other categories omitted
#         }
#     },
#     "PERCEPTION": {
#         "status": "Success",
#         "category_counts": {
#             "number_only": 98.89, "short_answer_other": 0.27, "json_answer:no": 0.11, "json_answer:small": 0.1,
#             "json_answer:large": 0.1, "json_answer:yes": 0.1 # ... other categories omitted
#         }
#     },
#     "REASONING": {
#         "status": "Success",
#         "category_counts": {
#             "point_prediction_prism": 20.0, "yes_no": 18.64, "multiple_choice_letter_spar7m": 13.11,
#             "simple_direction_sat": 6.93, "point_list_robospatial": 6.67, "long_descriptive_other": 6.07,
#             "short_answer_other": 5.15, "descriptive_sentence_spar7m": 2.51, "json_answer:small": 2.17,
#             "json_answer:large": 2.13, "json_answer:no": 1.78, "json_answer:yes": 1.71, "json_answer:1": 1.03,
#             "object_movement_sat": 0.94 # ... other categories omitted
#         }
#     },
#     "_2D": {
#         "status": "Success",
#         "category_counts": {
#             "number_only": 56.27, "simple_direction_sat": 25.85, "yes_no": 4.43, "json_answer:no": 2.68,
#             "json_answer:yes": 2.6, "json_answer:1": 1.49, "json_answer:0": 1.2, "json_answer:small": 0.72,
#             "json_answer:large": 0.7, "json_answer:2": 0.58 # ... other categories omitted
#         }
#     },
#     "_3D": {
#         "status": "Success",
#         "category_counts": {
#             "yes_no": 21.13, "point_prediction_prism": 20.91, "multiple_choice_letter_spar7m": 14.48,
#             "short_answer_other": 8.87, "point_list_robospatial": 6.97, "long_descriptive_other": 6.68,
#             "descriptive_sentence_spar7m": 2.62, "json_answer:large": 2.44, "json_answer:small": 2.44,
#             "object_movement_sat": 1.62 # ... other categories omitted
#         }
#     }
# }

# --- Configuration ---
INPUT_PATH = '__finetuning_data_analysis.json'
OUTPUT_FILENAME = "answer_format_distribution.png"
TOP_N = 10  # Categories plotted individually; the rest are grouped (adjust as needed)
JSON_PREFIX = "json_answer:"
JSON_LABEL = "json_answer"
OTHER_LABEL = "Other"


# --- Data Preparation ---
def load_analysis(path=INPUT_PATH):
    """Read the analysis JSON file at *path* and return the parsed object."""
    # Explicit encoding: do not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def simplify_category_counts(data):
    """Collapse every 'json_answer:...' category into one 'json_answer' entry.

    Args:
        data: Mapping of dataset name to an info dict holding a ``status``
            string and a ``category_counts`` mapping of category name to
            percentage.

    Returns:
        A dict mapping each dataset whose status is ``"Success"`` to its
        simplified category counts. Datasets with any other (or missing)
        status are skipped; a missing ``category_counts`` is treated as empty.
    """
    simplified = {}
    for dataset, info in data.items():
        if info.get("status") != "Success":
            continue
        merged = {}
        json_sum = 0
        for cat, perc in (info.get("category_counts") or {}).items():
            if cat.startswith(JSON_PREFIX):
                json_sum += perc
            else:
                merged[cat] = perc
        if json_sum > 0:
            # Rounded to two decimals to hide float accumulation noise.
            merged[JSON_LABEL] = round(json_sum, 2)
        simplified[dataset] = merged
    return simplified


def build_category_frame(simplified):
    """Return a DataFrame with one row per dataset and one column per category.

    Categories absent from a dataset are filled with 0.
    """
    return pd.DataFrame(simplified).fillna(0).T  # Transpose for plotting


# --- Select Top Categories and Group Others ---
def group_top_categories(df, top_n=TOP_N, other_label=OTHER_LABEL):
    """Keep the *top_n* largest categories and sum the rest into *other_label*.

    Categories are ranked by their total across all rows. A column that is
    already named *other_label* is never ranked; it is folded into the
    remainder bucket so the result cannot contain duplicate columns.

    Args:
        df: Frame with datasets as rows and categories as columns.
        top_n: Number of categories to keep as individual columns.
        other_label: Name of the column that receives the remainder.

    Returns:
        A new DataFrame whose columns are sorted by name for a consistent
        legend order. The remainder column is only present when at least one
        category was grouped.

    Raises:
        ValueError: If *top_n* is negative.
    """
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    ranked = df.drop(columns=other_label, errors="ignore")
    category_totals = ranked.sum().sort_values(ascending=False)
    top_categories = category_totals.index[:top_n].tolist()

    kept = set(top_categories)
    remainder = [col for col in df.columns if col not in kept]

    grouped = df[top_categories].copy()
    if remainder:
        grouped[other_label] = df[remainder].sum(axis=1)

    # Ensure DataFrame columns are sorted for consistent legend order
    return grouped[sorted(grouped.columns)]


# --- Plotting ---
def plot_answer_format_distribution(df, output_filename=OUTPUT_FILENAME):
    """Draw *df* as a stacked bar chart and save it to *output_filename*.

    Each row of *df* becomes one bar and each column one stacked segment.
    The figure is written to disk rather than shown interactively.
    """
    plt.style.use('seaborn-v0_8-whitegrid')  # Use a clean style
    fig, ax = plt.subplots(figsize=(14, 8))

    # One colour per category column
    colors = sns.color_palette("tab20", n_colors=len(df.columns))

    # Create the stacked bar chart
    df.plot(kind='bar', stacked=True, ax=ax, color=colors, width=0.8)

    # --- Formatting ---
    ax.set_xlabel("Dataset Category", fontsize=12, fontweight='bold')
    ax.set_ylabel("Percentage of Answer Formats (%)", fontsize=12, fontweight='bold')
    ax.set_title("Distribution of Answer Formats Across Fine-tuning Datasets",
                 fontsize=16, fontweight='bold', pad=20)
    ax.tick_params(axis='x', rotation=45, labelsize=11)
    ax.tick_params(axis='y', labelsize=11)
    ax.set_ylim(0, 100)
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.xaxis.grid(False)  # Remove vertical grid lines

    # NOTE: percentage labels inside the bars (only for segments > 5%) were
    # removed for clarity in this version.

    # --- Legend ---
    # Place legend outside the plot area
    ax.legend(title="Answer Format Categories", bbox_to_anchor=(1.02, 1),
              loc='upper left', borderaxespad=0., fontsize=10, title_fontsize=11)
    fig.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust layout to make space for legend

    # --- Save the plot instead of showing it ---
    # bbox_inches='tight' keeps the outside legend inside the saved image.
    fig.savefig(output_filename, dpi=300, bbox_inches='tight')
    plt.close(fig)  # Release the figure once it is on disk
    print(f"Plot saved as {output_filename}")


def main():
    """Load the analysis file, group the categories, and save the chart."""
    frame = build_category_frame(simplify_category_counts(load_analysis()))
    if frame.empty:
        # Fail with a clear message instead of an opaque plotting error.
        raise SystemExit(
            f"No category data with status 'Success' found in {INPUT_PATH}"
        )
    plot_answer_format_distribution(group_top_categories(frame))


if __name__ == "__main__":
    main()