File size: 2,443 Bytes
4097ba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import matplotlib.pyplot as plt
import pandas as pd
from typing import Optional
from .data_loader import DataLoader

class Plotter:
    """Render matplotlib charts from data supplied by a ``DataLoader``."""

    # Bar colour per model category; any other category is drawn in grey.
    _CATEGORY_COLORS = {
        "Text-Conditioned": "#3b82f6",
        "One-hot": "#10b981",
        "Intrinsics/Extrinsics": "#f59e0b",
    }
    _FALLBACK_COLOR = "#6b7280"

    def __init__(self, data_loader: DataLoader):
        """Store the loader used to fetch filtered model data.

        Args:
            data_loader: Data source; ``create_comparison_plot`` calls its
                ``filter_data`` method.
        """
        self.data_loader = data_loader

    @staticmethod
    def _empty_figure() -> plt.Figure:
        """Return a blank figure carrying a "No data available" notice."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, "No data available", ha="center", va="center", fontsize=14)
        ax.axis("off")
        return fig

    def create_comparison_plot(self, model_filter: str = "",
                               open_source_filter: str = "All",
                               year_filter: str = "All",
                               category_filter: str = "All",
                               metric: str = "Average",
                               sort_mode: str = "Descending (high → low)",
                               max_models: int = 20) -> plt.Figure:
        """Draw a horizontal bar chart comparing models on one metric.

        The four filter arguments are forwarded, in order, to
        ``self.data_loader.filter_data``. The frame it returns must provide
        ``"Model"`` and ``"Category"`` columns plus the ``metric`` column.

        Args:
            model_filter, open_source_filter, year_filter, category_filter:
                Passed through unchanged to ``filter_data`` (semantics are
                defined there).
            metric: Column whose values are plotted.
            sort_mode: Rows are sorted ascending when this starts with
                ``"Ascending"``, otherwise descending. The first sorted row is
                drawn at the top of the chart.
            max_models: Maximum number of bars (the first rows after
                sorting); expected to be positive.

        Returns:
            A new figure. When the metric column is missing or no row has a
            value for it, a placeholder "No data available" figure is
            returned instead.
        """
        df = self.data_loader.filter_data(model_filter, open_source_filter,
                                          year_filter, category_filter)
        if df.empty or metric not in df.columns:
            return self._empty_figure()

        ascending = sort_mode.startswith("Ascending")
        # Rows missing the metric can be neither drawn nor labelled, and NaN
        # would break the axis-limit computation below.
        df = (df.dropna(subset=[metric])
                .sort_values(by=metric, ascending=ascending)
                .head(max_models))
        if df.empty:
            return self._empty_figure()

        values = df[metric]
        categories = df["Category"]
        bar_colors = [self._CATEGORY_COLORS.get(cat, self._FALLBACK_COLOR)
                      for cat in categories]

        fig, ax = plt.subplots(figsize=(12, max(6, len(df) * 0.4)))
        # Numeric positions keep models that share a name on separate rows;
        # string positions are treated as categories and would overlap.
        positions = list(range(len(df)))
        bars = ax.barh(positions, values, color=bar_colors,
                       edgecolor="white", linewidth=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(df["Model"].astype(str))
        # barh draws the first row at the bottom; invert so the sorted order
        # reads top-to-bottom (highest first for "Descending").
        ax.invert_yaxis()

        # 15% headroom for the value labels; keep zero in range so bars start
        # at the axis, and fall back to a unit range when every value is 0.
        x_low = min(0.0, float(values.min()) * 1.15)
        x_high = max(0.0, float(values.max()) * 1.15)
        if x_low == x_high:
            x_high = 1.0
        # Offset scales with the axis (instead of a fixed 0.01) so labels
        # clear the bar end whatever the metric's magnitude.
        label_pad = (x_high - x_low) * 0.01
        for bar, val in zip(bars, values):
            ax.text(max(bar.get_width(), 0.0) + label_pad,
                    bar.get_y() + bar.get_height() / 2,
                    f"{val:.4f}", ha="left", va="center", fontsize=9)

        ax.set_xlim(x_low, x_high)
        ax.set_xlabel(metric, fontsize=12, fontweight="bold")
        ax.set_title(f"Model Comparison - {metric}", fontsize=14, fontweight="bold", pad=20)
        ax.grid(axis="x", alpha=0.3, linestyle="--")

        from matplotlib.patches import Patch
        present = set(categories)
        legend_elements = [Patch(facecolor=color, label=cat)
                           for cat, color in self._CATEGORY_COLORS.items()
                           if cat in present]
        if any(cat not in self._CATEGORY_COLORS for cat in present):
            # Grey fallback bars were drawn; give them a legend key too.
            legend_elements.append(Patch(facecolor=self._FALLBACK_COLOR, label="Other"))
        if legend_elements:
            ax.legend(handles=legend_elements, loc="lower right", title="Category")
        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        return fig