Spaces:
Running
Running
File size: 2,606 Bytes
4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 37b185d 4097ba4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from typing import Optional, List
from .data_loader import DataLoader
from .utils import get_metric_choices, clean_metric_names
class RadarPlotter:
    """Draw a polar "radar" chart comparing models across individual metrics.

    One radial axis is drawn per metric (the aggregate "Average ⭐" choice is
    excluded) and one closed, lightly filled polygon per model row.
    """

    def __init__(self, data_loader: DataLoader):
        """Store the data source and resolve the metric axes.

        Args:
            data_loader: Object whose ``df_all`` attribute is used as the
                fallback table when ``create_radar_chart`` gets no usable frame.
        """
        self.data_loader = data_loader
        # All concrete metrics, excluding the aggregate "Average ⭐" entry.
        # clean_metric_names presumably strips display markers so the names
        # match DataFrame column names — TODO confirm against utils.
        all_metrics_with_markers = get_metric_choices()
        self.metrics = clean_metric_names(
            [m for m in all_metrics_with_markers if m != "Average ⭐"]
        )

    @staticmethod
    def _message_figure(message: str) -> plt.Figure:
        """Return an empty polar figure that only shows *message* centred."""
        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection="polar"))
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        return fig

    def create_radar_chart(
        self,
        df: Optional[pd.DataFrame] = None,
        models: Optional[List[str]] = None,
        *,
        max_models: int = 6,
    ) -> plt.Figure:
        """Build a radar chart with one polygon per model.

        Args:
            df: Table with one row per model, a ``"Model"`` column and one
                column per metric. When ``None`` or empty, a copy of
                ``self.data_loader.df_all`` is used instead.
            models: Optional model names. When non-empty and the table has a
                ``"Model"`` column, only the matching rows are plotted.
            max_models: Maximum number of polygons drawn (must be >= 1). When
                more rows remain, the rows with the highest ``"Average"`` are
                kept, or the first rows if there is no ``"Average"`` column.

        Returns:
            A matplotlib Figure with one polar Axes. When there is nothing to
            plot, the figure only contains a short explanatory message.

        Raises:
            ValueError: If ``max_models`` is smaller than 1.
        """
        if max_models < 1:
            raise ValueError(f"max_models must be >= 1, got {max_models}")

        if df is None or df.empty:
            df_all = self.data_loader.df_all
            df = df_all.copy() if df_all is not None else pd.DataFrame()

        if models and "Model" in df.columns:
            df = df[df["Model"].isin(models)]

        if df.empty:
            return self._message_figure("No data available")

        # Limit the number of models shown so the chart stays readable.
        if len(df) > max_models:
            if "Average" in df.columns:
                df = df.nlargest(max_models, "Average")
            else:
                df = df.head(max_models)

        # Metrics (axes) that actually exist in this table.
        valid_metrics = [m for m in self.metrics if m in df.columns]
        if not valid_metrics:
            return self._message_figure("No valid metrics")

        # NaN / non-numeric cells would break the polygon outline and fill,
        # so coerce to numbers and treat missing scores as 0.
        values = (
            df[valid_metrics]
            .apply(pd.to_numeric, errors="coerce")
            .fillna(0.0)
            .to_numpy(dtype=float)
        )
        labels = (
            df["Model"].astype(str).tolist()
            if "Model" in df.columns
            else [str(i) for i in df.index]
        )

        # One evenly spaced angle per metric; repeat the first angle so each
        # polygon closes back on its starting point.
        n_axes = len(valid_metrics)
        angles = np.linspace(0, 2 * np.pi, n_axes, endpoint=False).tolist()
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection="polar"))
        colors = plt.cm.tab10(np.linspace(0, 1, len(df)))
        for color, label, row_values in zip(colors, labels, values):
            closed = [*row_values, row_values[0]]
            ax.plot(angles, closed, "o-", linewidth=2, label=label, color=color)
            ax.fill(angles, closed, alpha=0.1, color=color)

        # Axis labels and layout.
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(valid_metrics, fontsize=8)
        # Scores are presumably in [0, 1] — TODO confirm the metric scale;
        # widen the radial range instead of clipping values above 1.
        ax.set_ylim(0, max(1.0, float(values.max())))
        ax.set_title(
            f"Performance Radar ({n_axes} metrics)",
            fontsize=12,
            fontweight="bold",
            pad=20,
        )
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0), fontsize=8)
        ax.grid(True, linestyle="--", alpha=0.5)
        # Lay out this figure explicitly instead of pyplot's global "current" one.
        fig.tight_layout()
        return fig