Spaces:
Running
Running
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from typing import Optional | |
| from .data_loader import DataLoader | |
class Plotter:
    """Build matplotlib comparison charts from a ``DataLoader``'s filtered rows."""

    # Bar colour for each known value of the "Category" column.
    CATEGORY_COLORS = {
        "Text-Conditioned": "#3b82f6",
        "One-hot": "#10b981",
        "Intrinsics/Extrinsics": "#f59e0b",
    }
    # Colour for rows whose category is not in CATEGORY_COLORS (or when the
    # frame has no "Category" column at all).
    DEFAULT_COLOR = "#6b7280"

    def __init__(self, data_loader: DataLoader):
        """Store the loader whose ``filter_data`` supplies the rows to plot."""
        self.data_loader = data_loader

    def create_comparison_plot(self, model_filter: str = "",
                               open_source_filter: str = "All",
                               year_filter: str = "All",
                               category_filter: str = "All",
                               metric: str = "Average",
                               sort_mode: str = "Descending (high → low)",
                               max_models: int = 20) -> plt.Figure:
        """Return a horizontal bar chart comparing models on one metric column.

        Args:
            model_filter, open_source_filter, year_filter, category_filter:
                Passed through unchanged to ``self.data_loader.filter_data``.
            metric: Name of the column to plot; values are coerced to numbers
                and rows with missing/non-numeric/infinite values are skipped.
            sort_mode: Any string starting with "Ascending" sorts low→high;
                anything else sorts high→low. The first sorted row is drawn
                at the top of the chart.
            max_models: Keep at most this many rows after sorting
                (assumed positive).

        Returns:
            A ``matplotlib.figure.Figure``. When there is nothing plottable
            (empty result, unknown metric, or no numeric values) the figure
            only shows the text "No data available".

        The frame must contain a "Model" column (used for the y tick labels);
        a "Category" column, if present, drives bar colours and the legend.
        """
        df = self.data_loader.filter_data(model_filter, open_source_filter,
                                          year_filter, category_filter)
        if df.empty or metric not in df.columns:
            return self._message_figure("No data available")

        # Copy so the assignment below never mutates (or warns about) a frame
        # or slice owned by the loader.
        df = df.copy()
        # NaN/inf would make set_xlim raise and non-numeric values break barh,
        # so normalise to finite numbers and drop anything else.
        df[metric] = pd.to_numeric(df[metric], errors="coerce").replace(
            [float("inf"), float("-inf")], float("nan"))
        df = df.dropna(subset=[metric])
        if df.empty:
            return self._message_figure("No data available")

        ascending = sort_mode.startswith("Ascending")
        df = df.sort_values(by=metric, ascending=ascending).head(max_models)
        values = df[metric].tolist()

        # plt.Figure is the Figure *class* (not plt.figure()): the figure is not
        # registered with pyplot's global manager, so repeated calls from a
        # long-running app do not accumulate open figures.
        fig = plt.Figure(figsize=(12, max(6, len(df) * 0.4)))
        ax = fig.subplots()

        has_category = "Category" in df.columns
        categories = df["Category"].tolist() if has_category else []
        if has_category:
            bar_colors = [self.CATEGORY_COLORS.get(cat, self.DEFAULT_COLOR)
                          for cat in categories]
        else:
            bar_colors = self.DEFAULT_COLOR

        # Numeric positions + tick labels (instead of string categories) keep
        # rows with duplicate model names as separate bars.
        positions = list(range(len(df)))
        ax.barh(positions, values, color=bar_colors, edgecolor="white", linewidth=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(df["Model"].astype(str).tolist())
        # barh draws row 0 at the bottom; invert so the first sorted row is on top.
        ax.invert_yaxis()

        left, right = self._value_limits(min(values), max(values))
        ax.set_xlim(left, right)

        # Offset value labels by 1% of the axis span so the gap is the same
        # visually whatever the metric's scale.
        offset = (right - left) * 0.01
        for y, val in zip(positions, values):
            if val >= 0:
                ax.text(val + offset, y, f"{val:.4f}", ha="left", va="center", fontsize=9)
            else:
                ax.text(val - offset, y, f"{val:.4f}", ha="right", va="center", fontsize=9)

        ax.set_xlabel(metric, fontsize=12, fontweight="bold")
        ax.set_title(f"Model Comparison - {metric}", fontsize=14, fontweight="bold", pad=20)
        ax.grid(axis="x", alpha=0.3, linestyle="--")

        if has_category:
            from matplotlib.patches import Patch
            present = set(categories)
            legend_elements = [Patch(facecolor=color, label=cat)
                               for cat, color in self.CATEGORY_COLORS.items()
                               if cat in present]
            # Grey bars need a legend entry too.
            if any(cat not in self.CATEGORY_COLORS for cat in present):
                legend_elements.append(Patch(facecolor=self.DEFAULT_COLOR, label="Other"))
            # An empty handle list would still draw an empty titled box.
            if legend_elements:
                ax.legend(handles=legend_elements, loc="lower right", title="Category")

        fig.tight_layout()
        return fig

    @staticmethod
    def _message_figure(message: str) -> plt.Figure:
        """Return a 10x6 figure with no axes showing *message* centred."""
        fig = plt.Figure(figsize=(10, 6))
        ax = fig.subplots()
        ax.text(0.5, 0.5, message, ha="center", va="center", fontsize=14)
        ax.axis("off")
        return fig

    @staticmethod
    def _value_limits(vmin: float, vmax: float) -> tuple:
        """Return (left, right) x-limits that include 0 and pad the data side(s) by 15%.

        For all-non-negative data with a positive maximum this is
        ``(0, vmax * 1.15)``; all-zero data gets a non-degenerate ``(0, 0.15)``.
        """
        low = min(vmin, 0.0)
        high = max(vmax, 0.0)
        span = (high - low) or 1.0
        pad = span * 0.15
        left = low - pad if low < 0 else 0.0
        right = high + pad if (high > 0 or low == 0) else 0.0
        return left, right