# Source: Fine-tuning-data / analyze_data_vis.py
# Hugging Face page header captured with the file (not code):
#   "ch-min's picture" / "Upload folder using huggingface_hub" / 849ca03 (verified)
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import os # Import os for path joining
# Example of the expected input structure (abridged for brevity). The full data
# is loaded at run time from '__finetuning_data_analysis.json' further below.
# data = {
# "SYNTHETIC": {
# "status": "Success",
# "category_counts": {
# "point_prediction_prism": 32.85, "simple_direction_sat": 11.1, "yes_no": 8.51, "short_answer_other": 8.31,
# "number_only": 5.37, "json_answer:small": 3.15, "json_answer:large": 3.14, "json_answer:no": 2.68,
# "json_answer:yes": 2.6, "multiple_choice_letter_spar7m": 1.85, "object_movement_sat": 1.51,
# "json_answer:1": 1.49, "json_answer:0": 1.2, "json_answer:2": 0.95, # ... other json_answers omitted
# "long_descriptive_other": 0.82
# # ... other categories omitted for brevity in this display
# }
# },
# "REAL": {
# "status": "Success",
# "category_counts": {
# "yes_no": 33.33, "multiple_choice_letter_spar7m": 22.7, "point_list_robospatial": 16.67,
# "number_only": 12.29, "long_descriptive_other": 10.35, "descriptive_sentence_spar7m": 4.66
# }
# },
# "STATIC": {
# "status": "Success",
# "category_counts": {
# "point_prediction_prism": 20.0, "yes_no": 17.12, "multiple_choice_letter_spar7m": 13.26,
# "simple_direction_sat": 7.77, "point_list_robospatial": 6.67, "long_descriptive_other": 6.14,
# "number_only": 3.87, "short_answer_other": 2.68, "descriptive_sentence_spar7m": 2.51,
# "json_answer:large": 2.15, "json_answer:small": 2.15, "json_answer:no": 1.81, "json_answer:yes": 1.78,
# "json_answer:1": 1.01 # ... other categories omitted
# }
# },
# "DYNAMIC": {
# "status": "Success",
# "category_counts": {
# "number_only": 48.02, "short_answer_other": 23.16, "yes_no": 16.92, "object_movement_sat": 7.87,
# "json_answer:large": 0.62, "json_answer:small": 0.57 # ... other categories omitted
# }
# },
# "PERCEPTION": {
# "status": "Success",
# "category_counts": {
# "number_only": 98.89, "short_answer_other": 0.27, "json_answer:no": 0.11, "json_answer:small": 0.1,
# "json_answer:large": 0.1, "json_answer:yes": 0.1 # ... other categories omitted
# }
# },
# "REASONING": {
# "status": "Success",
# "category_counts": {
# "point_prediction_prism": 20.0, "yes_no": 18.64, "multiple_choice_letter_spar7m": 13.11,
# "simple_direction_sat": 6.93, "point_list_robospatial": 6.67, "long_descriptive_other": 6.07,
# "short_answer_other": 5.15, "descriptive_sentence_spar7m": 2.51, "json_answer:small": 2.17,
# "json_answer:large": 2.13, "json_answer:no": 1.78, "json_answer:yes": 1.71, "json_answer:1": 1.03,
# "object_movement_sat": 0.94 # ... other categories omitted
# }
# },
# "_2D": {
# "status": "Success",
# "category_counts": {
# "number_only": 56.27, "simple_direction_sat": 25.85, "yes_no": 4.43, "json_answer:no": 2.68,
# "json_answer:yes": 2.6, "json_answer:1": 1.49, "json_answer:0": 1.2, "json_answer:small": 0.72,
# "json_answer:large": 0.7, "json_answer:2": 0.58 # ... other categories omitted
# }
# },
# "_3D": {
# "status": "Success",
# "category_counts": {
# "yes_no": 21.13, "point_prediction_prism": 20.91, "multiple_choice_letter_spar7m": 14.48,
# "short_answer_other": 8.87, "point_list_robospatial": 6.97, "long_descriptive_other": 6.68,
# "descriptive_sentence_spar7m": 2.62, "json_answer:large": 2.44, "json_answer:small": 2.44,
# "object_movement_sat": 1.62 # ... other categories omitted
# }
# }
# }
# --- Data Preparation ---
import json

INPUT_FILENAME = '__finetuning_data_analysis.json'
OUTPUT_FILENAME = "answer_format_distribution.png"
JSON_ANSWER_PREFIX = "json_answer:"
JSON_ANSWER_LABEL = 'json_answer'
OTHER_LABEL = 'Other'
DEFAULT_TOP_N = 10


def load_analysis(path=INPUT_FILENAME):
    """Load the per-dataset analysis results from a JSON file.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON document (a mapping of dataset name to its info dict).
    """
    # JSON is UTF-8 by specification; do not rely on the locale's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def simplify_category_counts(data):
    """Collapse every 'json_answer:<value>' category into one 'json_answer' entry.

    Only datasets whose ``status`` equals ``'Success'`` are kept. Entries that
    lack a ``status`` or ``category_counts`` key are skipped / treated as empty
    instead of raising ``KeyError``.

    Args:
        data: Mapping of dataset name -> {"status": ..., "category_counts": {...}}.

    Returns:
        Mapping of dataset name -> {category: percentage} with the merged
        'json_answer' value rounded to two decimals.
    """
    simplified = {}
    for dataset, info in data.items():
        if info.get('status') != 'Success':
            continue
        merged = {}
        json_sum = 0
        for cat, perc in info.get('category_counts', {}).items():
            if cat.startswith(JSON_ANSWER_PREFIX):
                json_sum += perc
            else:
                merged[cat] = perc
        if json_sum > 0:
            # Add to (rather than overwrite) a plain 'json_answer' entry if the
            # input already contains one, so no percentage is silently lost.
            merged[JSON_ANSWER_LABEL] = round(
                merged.get(JSON_ANSWER_LABEL, 0) + json_sum, 2
            )
        simplified[dataset] = merged
    return simplified


def build_dataframe(simplified):
    """Build a DataFrame with datasets as rows and categories as columns.

    Categories absent from a dataset are filled with 0.

    Raises:
        ValueError: If there is nothing to plot (no datasets or no categories).
    """
    # Transpose so that each dataset becomes one bar in the stacked bar chart.
    df = pd.DataFrame(simplified).fillna(0).T
    if df.empty:
        raise ValueError("No successful datasets with category counts to plot")
    return df


def group_top_categories(df, top_n=DEFAULT_TOP_N, other_label=OTHER_LABEL):
    """Keep the ``top_n`` largest categories and sum the rest into ``other_label``.

    Categories are ranked by their total percentage across all datasets. A
    column already named ``other_label`` is never ranked; it is always folded
    into the remainder so the result cannot contain a duplicated or
    overwritten ``other_label`` column.

    Args:
        df: DataFrame with datasets as rows and categories as columns.
        top_n: Number of categories to keep as individual columns (>= 0).
        other_label: Column name used for the grouped remainder.

    Returns:
        A new DataFrame whose columns are sorted alphabetically for a
        consistent legend order. The input frame is not modified.

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    if top_n < 0:
        raise ValueError("top_n must be non-negative")

    rankable = df.drop(columns=[other_label], errors='ignore')
    category_totals = rankable.sum().sort_values(ascending=False)
    top_categories = category_totals.head(top_n).index.tolist()

    kept = set(top_categories)
    remainder = [col for col in df.columns if col not in kept]

    grouped = df[top_categories].copy()
    if remainder:
        grouped[other_label] = df[remainder].sum(axis=1)

    return grouped[sorted(grouped.columns)]


def plot_distribution(df, output_filename=OUTPUT_FILENAME):
    """Render ``df`` as a stacked bar chart and save it to ``output_filename``.

    Args:
        df: DataFrame with one row per dataset and one column per category.
        output_filename: Path of the image file to write.

    Returns:
        The path the figure was written to.
    """
    # 'seaborn-v0_8-whitegrid' is the matplotlib >= 3.6 name; older releases
    # ship the same style as 'seaborn-whitegrid'. Use whichever is available.
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    fig, ax = plt.subplots(figsize=(14, 8))
    colors = sns.color_palette("tab20", n_colors=len(df.columns))
    df.plot(kind='bar', stacked=True, ax=ax, color=colors, width=0.8)

    # --- Formatting ---
    ax.set_xlabel("Dataset Category", fontsize=12, fontweight='bold')
    ax.set_ylabel("Percentage of Answer Formats (%)", fontsize=12, fontweight='bold')
    ax.set_title(
        "Distribution of Answer Formats Across Fine-tuning Datasets",
        fontsize=16,
        fontweight='bold',
        pad=20,
    )
    ax.tick_params(axis='x', rotation=45, labelsize=11)
    ax.tick_params(axis='y', labelsize=11)
    ax.set_ylim(0, 100)
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.xaxis.grid(False)  # Remove vertical grid lines
    # Per-segment percentage labels are intentionally omitted for clarity.

    # --- Legend ---
    # Place the legend outside the plot area and reserve room for it.
    ax.legend(
        title="Answer Format Categories",
        bbox_to_anchor=(1.02, 1),
        loc='upper left',
        borderaxespad=0.,
        fontsize=10,
        title_fontsize=11,
    )
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    # Save instead of showing; bbox_inches='tight' keeps the outside legend in frame.
    plt.savefig(output_filename, dpi=300, bbox_inches='tight')
    plt.close(fig)  # Release the figure once it has been written to disk
    return output_filename


def main():
    """Load the analysis JSON, build the chart, and save it as a PNG."""
    data = load_analysis()
    df = build_dataframe(simplify_category_counts(data))
    df = group_top_categories(df, top_n=DEFAULT_TOP_N)
    output_filename = plot_distribution(df)
    print(f"Plot saved as {output_filename}")


if __name__ == "__main__":
    main()