Spaces:
Running
Running
File size: 2,443 Bytes
4097ba4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | import matplotlib.pyplot as plt
import pandas as pd
from typing import Optional
from .data_loader import DataLoader
class Plotter:
    """Draw horizontal bar charts that compare models on a single metric.

    Filtering is delegated to ``DataLoader.filter_data``; this class only
    orders the filtered rows and renders them.
    """

    # Bar colour for each known value of the "Category" column.
    CATEGORY_COLORS = {
        "Text-Conditioned": "#3b82f6",
        "One-hot": "#10b981",
        "Intrinsics/Extrinsics": "#f59e0b",
    }
    # Colour used for any category not listed above.
    DEFAULT_COLOR = "#6b7280"

    def __init__(self, data_loader: DataLoader):
        """Keep a reference to the loader that supplies the results table."""
        self.data_loader = data_loader

    @staticmethod
    def _new_figure(figsize):
        """Return ``(fig, ax)`` for a Figure that pyplot does not track."""
        # pyplot keeps every figure made by plt.subplots() alive until
        # plt.close() is called, so a long-running app would leak one figure
        # per call. A bare Figure is freed once the caller drops it, and
        # operating on it explicitly avoids pyplot's global "current figure".
        from matplotlib.figure import Figure

        fig = Figure(figsize=figsize)
        ax = fig.subplots()
        return fig, ax

    @classmethod
    def _empty_figure(cls):
        """Return a placeholder figure stating that there is nothing to plot."""
        fig, ax = cls._new_figure((10, 6))
        ax.text(0.5, 0.5, "No data available", ha="center", va="center", fontsize=14)
        ax.axis("off")
        return fig

    def create_comparison_plot(self, model_filter: str = "",
                               open_source_filter: str = "All",
                               year_filter: str = "All",
                               category_filter: str = "All",
                               metric: str = "Average",
                               sort_mode: str = "Descending (high → low)",
                               top_n: int = 20) -> plt.Figure:
        """Plot the first ``top_n`` filtered models ranked by ``metric``.

        Args:
            model_filter, open_source_filter, year_filter, category_filter:
                Passed unchanged, in this order, to ``DataLoader.filter_data``.
            metric: Column to plot and sort by.
            sort_mode: The sort is ascending when this starts with
                "Ascending", descending otherwise. Bars are drawn
                top-to-bottom in that order.
            top_n: Maximum number of rows to plot, taken from the start of
                the sorted table. Defaults to 20.

        Returns:
            A matplotlib Figure. When ``metric`` is not a column, or no rows
            with a non-NaN ``metric`` value remain, the figure contains only
            the text "No data available".
        """
        from matplotlib.patches import Patch

        df = self.data_loader.filter_data(model_filter, open_source_filter,
                                          year_filter, category_filter)
        if metric not in df.columns:
            return self._empty_figure()

        # NaN scores cannot be drawn as bars and would make the axis limits NaN.
        df = df.dropna(subset=[metric])
        ascending = sort_mode.startswith("Ascending")
        df = df.sort_values(by=metric, ascending=ascending).head(top_n)
        if df.empty:
            return self._empty_figure()

        values = df[metric]
        positions = list(range(len(df)))
        fig, ax = self._new_figure((12, max(6, len(df) * 0.4)))

        bar_colors = [self.CATEGORY_COLORS.get(cat, self.DEFAULT_COLOR)
                      for cat in df["Category"]]
        # Numeric y positions (rather than the model names themselves) keep
        # rows that share a model name on separate bars.
        ax.barh(positions, values, color=bar_colors, edgecolor="white", linewidth=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(df["Model"].astype(str))
        # barh draws the first row at the bottom; flip the axis so the chart
        # reads top-to-bottom in the requested sort order.
        ax.invert_yaxis()

        # The x range always contains 0 and every bar; for non-negative data
        # this is (0, max * 1.15), leaving headroom for the value labels.
        low = min(0.0, float(values.min()))
        high = max(0.0, float(values.max()))
        span = (high - low) or 1.0  # all-zero data would otherwise give an empty range
        pad = span * 0.01
        for y, val in zip(positions, values):
            # Put the label just past the end of the bar, on the side it extends to.
            x, ha = (val + pad, "left") if val >= 0 else (val - pad, "right")
            ax.text(x, y, f"{val:.4f}", ha=ha, va="center", fontsize=9)

        ax.set_xlabel(metric, fontsize=12, fontweight="bold")
        ax.set_title(f"Model Comparison - {metric}", fontsize=14, fontweight="bold", pad=20)
        ax.set_xlim(low - span * 0.15 if low < 0 else 0, high + span * 0.15)
        ax.grid(axis="x", alpha=0.3, linestyle="--")

        # Only list categories that actually appear among the plotted rows.
        present = set(df["Category"])
        legend_elements = [Patch(facecolor=color, label=cat)
                           for cat, color in self.CATEGORY_COLORS.items()
                           if cat in present]
        if legend_elements:
            ax.legend(handles=legend_elements, loc="lower right", title="Category")

        fig.tight_layout()
        return fig