import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from typing import Optional, List

from .data_loader import DataLoader
from .utils import get_metric_choices, clean_metric_names


class RadarPlotter:
    """Render a polar "radar" chart comparing models across per-metric scores."""

    def __init__(self, data_loader: DataLoader):
        """Store the data source and resolve the metric axes to plot.

        Args:
            data_loader: Provides the fallback table via its ``df_all``
                attribute, used when ``create_radar_chart`` receives no data.
        """
        self.data_loader = data_loader
        # Collect every individual metric (excluding the aggregate
        # "Average ⭐" entry) and normalise the names via clean_metric_names
        # so they can be matched against DataFrame columns.
        all_metrics_with_markers = get_metric_choices()
        self.metrics = clean_metric_names(
            [m for m in all_metrics_with_markers if m != "Average ⭐"]
        )

    @staticmethod
    def _placeholder_figure(message: str) -> plt.Figure:
        """Return an empty polar figure displaying ``message`` in its centre."""
        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection="polar"))
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        return fig

    def create_radar_chart(
        self,
        df: Optional[pd.DataFrame] = None,
        models: Optional[List[str]] = None,
        max_models: Optional[int] = 6,
    ) -> plt.Figure:
        """Build a radar chart with one closed polygon per model (row).

        Args:
            df: Table with one row per model. Metric columns are matched
                against ``self.metrics``. A "Model" column, when present,
                supplies legend labels. If ``None`` or empty, a copy of
                ``self.data_loader.df_all`` is used instead.
            models: Optional list of model names. When non-empty and ``df``
                has a "Model" column, only those rows are plotted. An empty
                list or ``None`` means "no filter".
            max_models: Maximum number of series to draw (default 6). When
                exceeded, the rows with the highest "Average" are kept, or
                the first rows if there is no "Average" column. ``None``
                disables the cap.

        Returns:
            The matplotlib Figure. When there is nothing to draw, it is a
            placeholder figure with an explanatory message.
        """
        if df is None or df.empty:
            df_all = self.data_loader.df_all
            df = df_all.copy() if df_all is not None else pd.DataFrame()

        # Honour the caller's model selection (boolean indexing returns a
        # new frame, so the caller's DataFrame is never mutated).
        if models and "Model" in df.columns:
            df = df[df["Model"].isin(models)]

        if df.empty:
            return self._placeholder_figure("No data available")

        # Limit the number of plotted models so the chart stays legible.
        if max_models is not None and len(df) > max_models:
            if "Average" in df.columns:
                df = df.nlargest(max_models, "Average")
            else:
                df = df.head(max_models)

        # Metrics (axes) actually present in this table.
        valid_metrics = [m for m in self.metrics if m in df.columns]
        if not valid_metrics:
            return self._placeholder_figure("No valid metrics")

        # One evenly spaced angle per axis; repeat the first angle so each
        # polygon closes back on its starting point.
        n_axes = len(valid_metrics)
        angles = np.linspace(0, 2 * np.pi, n_axes, endpoint=False).tolist()
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection="polar"))

        # Coerce scores to numbers. Missing or non-numeric cells plot as 0
        # rather than NaN, which would break the line and the filled polygon.
        scores = (
            df[valid_metrics]
            .apply(pd.to_numeric, errors="coerce")
            .fillna(0)
            .to_numpy(dtype=float)
        )
        labels = df["Model"] if "Model" in df.columns else df.index.astype(str)

        colors = plt.cm.tab10(np.linspace(0, 1, len(df)))
        for color, label, row_values in zip(colors, labels, scores):
            values = row_values.tolist()
            values += values[:1]
            ax.plot(angles, values, "o-", linewidth=2, label=label, color=color)
            ax.fill(angles, values, alpha=0.1, color=color)

        # Axis labels and styling.
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(valid_metrics, fontsize=8)
        # NOTE(review): radial range fixed to [0, 1]; this assumes scores are
        # normalised to that range — confirm against the data source.
        ax.set_ylim(0, 1)
        ax.set_title(
            f"Performance Radar ({n_axes} metrics)",
            fontsize=12,
            fontweight="bold",
            pad=20,
        )
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0), fontsize=8)
        ax.grid(True, linestyle="--", alpha=0.5)
        fig.tight_layout()
        return fig