from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Patch

from .data_loader import DataLoader


class Plotter:
    """Build matplotlib comparison charts from rows supplied by a ``DataLoader``."""

    # Bar colour per model category; categories not listed here use DEFAULT_COLOR.
    CATEGORY_COLORS = {
        "Text-Conditioned": "#3b82f6",
        "One-hot": "#10b981",
        "Intrinsics/Extrinsics": "#f59e0b",
    }
    DEFAULT_COLOR = "#6b7280"

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader

    @staticmethod
    def _empty_figure() -> plt.Figure:
        """Return a placeholder figure telling the user there is nothing to plot."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, "No data available", ha="center", va="center", fontsize=14)
        ax.axis("off")
        return fig

    def create_comparison_plot(
        self,
        model_filter: str = "",
        open_source_filter: str = "All",
        year_filter: str = "All",
        category_filter: str = "All",
        metric: str = "Average",
        sort_mode: str = "Descending (high → low)",
        top_n: Optional[int] = 20,
    ) -> plt.Figure:
        """Plot a horizontal bar chart comparing models on one metric column.

        Rows come from ``self.data_loader.filter_data`` called with the four
        filter arguments. The resulting frame is expected to contain
        ``Model`` and ``Category`` columns, since both are indexed directly
        below.

        Args:
            model_filter, open_source_filter, year_filter, category_filter:
                Passed through unchanged to ``DataLoader.filter_data``.
            metric: Name of the column to plot. Values are coerced to
                numbers, and rows without a numeric value are skipped.
            sort_mode: A value starting with ``"Ascending"`` sorts low to
                high. Anything else sorts high to low. The first sorted row
                is drawn at the top of the chart.
            top_n: Maximum number of bars to draw after sorting. ``None``
                draws every row.

        Returns:
            The matplotlib figure. When there is nothing to plot (no rows,
            an unknown metric, or no numeric values), a "No data available"
            placeholder figure is returned instead.

        Raises:
            ValueError: If ``top_n`` is not ``None`` and is less than 1.
        """
        if top_n is not None and top_n < 1:
            raise ValueError(f"top_n must be >= 1 or None, got {top_n!r}")

        df = self.data_loader.filter_data(
            model_filter, open_source_filter, year_filter, category_filter
        )
        if df.empty or metric not in df.columns:
            return self._empty_figure()

        # Copy before assigning so we never mutate a frame the loader may
        # share or cache. NaN/non-numeric values would otherwise give "nan"
        # labels, or make set_xlim raise when every value is NaN.
        df = df.copy()
        df[metric] = pd.to_numeric(df[metric], errors="coerce")
        df = df.dropna(subset=[metric])
        if df.empty:
            return self._empty_figure()

        ascending = sort_mode.startswith("Ascending")
        df = df.sort_values(by=metric, ascending=ascending)
        if top_n is not None:
            df = df.head(top_n)

        fig, ax = plt.subplots(figsize=(12, max(6, len(df) * 0.4)))

        bar_colors = [
            self.CATEGORY_COLORS.get(cat, self.DEFAULT_COLOR) for cat in df["Category"]
        ]
        bars = ax.barh(
            df["Model"], df[metric], color=bar_colors, edgecolor="white", linewidth=0.5
        )
        # barh puts the first row at the bottom; invert so the sorted order
        # reads top-to-bottom as the sort_mode label describes.
        ax.invert_yaxis()

        for bar, val in zip(bars, df[metric]):
            # Offset in points (not data units) so the gap is the same at any
            # value scale; negative bars get their label on the left.
            ax.annotate(
                f"{val:.4f}",
                xy=(bar.get_width(), bar.get_y() + bar.get_height() / 2),
                xytext=(3 if val >= 0 else -3, 0),
                textcoords="offset points",
                ha="left" if val >= 0 else "right",
                va="center",
                fontsize=9,
            )

        ax.set_xlabel(metric, fontsize=12, fontweight="bold")
        ax.set_title(f"Model Comparison - {metric}", fontsize=14, fontweight="bold", pad=20)

        # Always include 0 and leave 15% head-room for the value labels.
        # For all-positive data this equals the original (0, max * 1.15);
        # it also avoids an empty (0, 0) range and clipping negative bars.
        lo = min(0.0, float(df[metric].min()))
        hi = max(0.0, float(df[metric].max()))
        pad = 0.15 * ((hi - lo) or 1.0)
        ax.set_xlim(lo - pad if lo < 0 else 0.0, hi + pad)
        ax.grid(axis="x", alpha=0.3, linestyle="--")

        present = set(df["Category"])
        legend_elements = [
            Patch(facecolor=color, label=cat)
            for cat, color in self.CATEGORY_COLORS.items()
            if cat in present
        ]
        # Skip the legend entirely rather than drawing an empty titled box.
        if legend_elements:
            ax.legend(handles=legend_elements, loc="lower right", title="Category")

        # Lay out *this* figure, not whatever pyplot considers current.
        fig.tight_layout()
        return fig