|
|
""" |
|
|
Leaderboard Styling Module |
|
|
|
|
|
Handles color gradients and visual styling for the leaderboard. |
|
|
""" |
|
|
|
|
|
import html
import logging
from typing import Dict, List, Tuple
from urllib.parse import quote

import pandas as pd
from matplotlib.colors import LinearSegmentedColormap

from ..core.columns import column_registry
from .transformer import format_parameter_count
|
|
|
|
|
# Logger named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
class LeaderboardStyler:
    """
    Applies visual styling to leaderboard DataFrames.

    Uses Excel-like Red-Yellow-Green color gradients for score columns: within
    each score column the lowest value is shaded red, the middle of the range
    yellow and the highest value green. Numeric columns are formatted with the
    decimal places declared in the column registry, and the ``Model`` column
    is rendered as a link to the model's Hugging Face page.
    """

    # Gradient anchors as (r, g, b) in the 0-1 range, from lowest to highest
    # score: red -> yellow -> Excel green (#00B050).
    GRADIENT_COLORS = [
        (0.9, 0.1, 0.2),
        (1.0, 1.0, 0.0),
        (0/255, 176/255, 80/255),
    ]

    # Decimal places used when a numeric column's definition declares none.
    DEFAULT_DECIMALS = 2

    def __init__(self):
        # 256-step interpolated lookup table built from GRADIENT_COLORS.
        self._colormap = LinearSegmentedColormap.from_list(
            "ExcelRedYellowGreen",
            self.GRADIENT_COLORS,
            N=256,
        )

    @staticmethod
    def rgb_to_hex(rgb: Tuple[float, float, float]) -> str:
        """
        Convert an RGB tuple with channels in the 0-1 range to ``#rrggbb``.

        Each channel is rounded to the nearest 0-255 integer (truncation would
        bias colors one step darker) and clamped, so slightly out-of-range
        floats still yield a valid six-digit color. Extra entries beyond the
        first three (e.g. an alpha channel) are ignored.
        """
        channels = (min(255, max(0, round(float(c) * 255))) for c in rgb[:3])
        return "#" + "".join(f"{c:02x}" for c in channels)

    def get_color_for_value(self, value: float, min_val: float, max_val: float) -> str:
        """
        Map *value* to a hex color on the gradient spanning ``[min_val, max_val]``.

        Values outside the range saturate at the nearest end of the gradient.
        A degenerate range (``min_val == max_val``) uses the midpoint color.
        """
        if max_val == min_val:
            normalized = 0.5
        else:
            normalized = (value - min_val) / (max_val - min_val)

        # Clamp into [0, 0.999]; with a 256-entry table, 0.999 still selects
        # the final (greenest) entry.
        normalized = max(0.0, min(0.999, normalized))

        rgba = self._colormap(normalized)
        return self.rgb_to_hex(rgba[:3])

    def calculate_color_ranges(self, df: pd.DataFrame) -> Dict[str, Dict[str, float]]:
        """
        Calculate the min/max of every registered score column present in *df*.

        Values are coerced with ``pd.to_numeric(errors='coerce')`` so
        non-numeric entries are ignored. Columns that are absent, or that hold
        no numeric values at all, are omitted from the result.

        Returns:
            ``{column_name: {'min': ..., 'max': ...}}``
        """
        ranges = {}
        for col_name in column_registry.score_columns:
            if col_name not in df.columns:
                continue

            numeric_values = pd.to_numeric(df[col_name], errors='coerce')
            if numeric_values.isna().all():
                continue

            ranges[col_name] = {
                'min': float(numeric_values.min()),
                'max': float(numeric_values.max()),
            }
        return ranges

    def apply_styling(self, df: pd.DataFrame) -> "pd.io.formats.style.Styler":
        """
        Apply color styling and display formatting to a leaderboard DataFrame.

        Score columns are coerced to numbers (``"N/A"`` and other unparseable
        entries become missing) and shaded with the red-yellow-green gradient
        relative to that column's own min/max. Numeric columns are formatted
        per their registry definition, and ``Model`` becomes an HTML link.

        Args:
            df: Leaderboard table. It is not modified; styling works on a copy.

        Returns:
            A pandas ``Styler`` that Gradio can render. For an empty *df* the
            plain ``df.style`` is returned.
        """
        if df.empty:
            return df.style

        df_copy = self._coerce_score_columns(df)
        color_ranges = self.calculate_color_ranges(df_copy)

        def apply_gradient(val, col_name: str) -> str:
            # CSS for one cell; '' leaves the cell unstyled (no range or missing value).
            col_range = color_ranges.get(col_name)
            if col_range is None or pd.isna(val):
                return ''
            try:
                color_hex = self.get_color_for_value(float(val), col_range['min'], col_range['max'])
            except (ValueError, TypeError):
                return ''
            return f'background-color: {color_hex}; text-align: center; font-weight: bold; color: #333;'

        styler = df_copy.style
        for col in column_registry.score_columns:
            if col in df_copy.columns:
                # Bind ``col`` as a default so each lambda keeps its own column name.
                styler = styler.map(
                    lambda val, c=col: apply_gradient(val, c),
                    subset=[col],
                )

        format_dict = self._build_format_dict(df_copy.columns)
        if format_dict:
            # escape=None: pandas does not escape cell text, so the Model link
            # renders as HTML; the formatters built here escape text themselves.
            styler = styler.format(format_dict, na_rep='', escape=None)

        return styler

    @staticmethod
    def _coerce_score_columns(df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of *df* whose score columns are numeric ("N/A"/unparseable -> missing)."""
        df_copy = df.copy()
        for col in column_registry.score_columns:
            if col in df_copy.columns:
                df_copy[col] = pd.to_numeric(df_copy[col].replace("N/A", pd.NA), errors='coerce')
        return df_copy

    def _build_format_dict(self, columns: pd.Index) -> Dict[str, object]:
        """Map each present numeric column (and ``Model``) to its cell formatter."""
        format_dict = {}
        for col_name in column_registry.numeric_columns:
            if col_name not in columns:
                continue
            if col_name == "Parameters":
                format_dict[col_name] = format_parameter_count
            else:
                decimals = self._resolve_decimals(column_registry.get(col_name))
                format_dict[col_name] = self._make_number_formatter(decimals)

        if "Model" in columns:
            format_dict["Model"] = self._format_model_link
        return format_dict

    @classmethod
    def _resolve_decimals(cls, col_def: object) -> int:
        """Return the column definition's non-negative ``decimals``, else DEFAULT_DECIMALS."""
        decimals = getattr(col_def, "decimals", None)
        if decimals is None:
            return cls.DEFAULT_DECIMALS
        try:
            places = int(decimals)
        except (TypeError, ValueError):
            return cls.DEFAULT_DECIMALS
        return places if places >= 0 else cls.DEFAULT_DECIMALS

    @staticmethod
    def _make_number_formatter(decimals: int):
        """
        Build a cell formatter that renders numbers with *decimals* places.

        Values that cannot be converted with ``float()`` (e.g. a stray "N/A"
        string) are shown as HTML-escaped text instead of raising during render.
        """
        def _format(value: object) -> str:
            try:
                return f"{float(value):.{decimals}f}"
            except (TypeError, ValueError):
                return html.escape(str(value))
        return _format

    @staticmethod
    def _format_model_link(value: object) -> str:
        """Render a model id as an HTML link to its Hugging Face page."""
        raw = str(value)
        # Percent-encode the URL path so '#', '?', spaces etc. cannot break the
        # link; HTML-escape the visible label.
        href = quote(raw, safe="/")
        label = html.escape(raw)
        return (
            f'<a href="https://huggingface.co/{href}" target="_blank" '
            f'style="color: #2563eb; text-decoration: underline;">{label}</a>'
        )

    def get_datatypes(self, columns: List[str]) -> List[str]:
        """Get Gradio datatypes for columns (delegates to the column registry)."""
        return column_registry.get_datatypes(columns)

    def get_column_widths(self, columns: List[str]) -> List[str]:
        """Get column widths for columns (delegates to the column registry)."""
        return column_registry.get_widths(columns)
|
|
|