# Mizan/src/data/styler.py
# Last commit: bc37111 "Refactor codebase structure" (nmmursit)
"""
Leaderboard Styling Module
Handles color gradients and visual styling for the leaderboard.
"""
import logging
import html
from typing import Dict, Tuple, List
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap
from ..core.columns import column_registry
from .transformer import format_parameter_count
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class LeaderboardStyler:
    """
    Applies visual styling to leaderboard DataFrames.

    Uses Excel-like Red-Yellow-Green color gradients for score columns,
    fixed-precision formatting for numeric columns and an HTML hyperlink
    for the ``Model`` column.
    """

    # Excel-style color gradient: Red -> Yellow -> Green (RGB, 0-1 range)
    GRADIENT_COLORS = [
        (0.9, 0.1, 0.2),  # Red (low scores)
        (1.0, 1.0, 0.0),  # Yellow (medium scores)
        (0/255, 176/255, 80/255)  # Excel Green (high scores)
    ]

    def __init__(self):
        """Build the 256-step colormap used to shade score cells."""
        self._colormap = LinearSegmentedColormap.from_list(
            "ExcelRedYellowGreen",
            self.GRADIENT_COLORS,
            N=256
        )

    @staticmethod
    def rgb_to_hex(rgb: Tuple[float, float, float]) -> str:
        """
        Convert an RGB tuple (0-1 range) to a ``#rrggbb`` hex color.

        Each channel is rounded to the nearest 8-bit level (plain truncation
        biases every channel downward, e.g. 0.999 -> 254) and clamped to
        0-255 so that slightly out-of-range input still produces a valid
        six-digit color. Components after the first three (such as an alpha
        channel) are ignored.

        Raises:
            IndexError: if ``rgb`` has fewer than three components.
        """
        def to_byte(channel: float) -> int:
            # float() normalizes NumPy scalars so round() returns a plain int.
            return min(255, max(0, int(round(float(channel) * 255))))

        r = to_byte(rgb[0])
        g = to_byte(rgb[1])
        b = to_byte(rgb[2])
        return f"#{r:02x}{g:02x}{b:02x}"

    def get_color_for_value(self, value: float, min_val: float, max_val: float) -> str:
        """
        Get the hex color for a value within a range.

        The value is mapped linearly from ``[min_val, max_val]`` onto the
        gradient; values outside the range take the nearest end color. When
        the range is degenerate (``min_val == max_val``) the gradient
        midpoint is used. ``value`` is expected to be a real number;
        ``apply_styling`` skips missing cells before calling this method.
        """
        if max_val == min_val:
            normalized = 0.5
        else:
            normalized = (value - min_val) / (max_val - min_val)
        # Clamp to [0, 0.999]; the upper bound is kept just below 1.0 as a
        # guard for the colormap lookup at exactly 1.0.
        normalized = max(0, min(0.999, normalized))
        rgba = self._colormap(normalized)
        return self.rgb_to_hex(rgba[:3])

    def calculate_color_ranges(self, df: pd.DataFrame) -> Dict[str, Dict[str, float]]:
        """
        Calculate min/max for each score column.

        Only registered score columns that are present in ``df`` and contain
        at least one numeric value are included. Non-numeric entries are
        ignored (coerced to NaN) when computing the bounds.

        Returns:
            Mapping of column name to ``{'min': ..., 'max': ...}``.
        """
        ranges = {}
        for col_name in column_registry.score_columns:
            if col_name not in df.columns:
                continue
            numeric_values = pd.to_numeric(df[col_name], errors='coerce')
            if numeric_values.isna().all():
                continue
            ranges[col_name] = {
                'min': numeric_values.min(),
                'max': numeric_values.max()
            }
        return ranges

    def _gradient_css(self, val, min_val: float, max_val: float) -> str:
        """Return the CSS for one score cell, or '' if it cannot be colored."""
        if pd.isna(val):
            return ''
        try:
            color_hex = self.get_color_for_value(float(val), min_val, max_val)
        except (ValueError, TypeError):
            return ''
        return f'background-color: {color_hex}; text-align: center; font-weight: bold; color: #333;'

    @staticmethod
    def _model_link_formatter(value: object) -> str:
        """Render a model name as an HTML link to its Hugging Face page."""
        model_name = html.escape(str(value))
        return (
            f'<a href="https://huggingface.co/{model_name}" target="_blank" '
            f'style="color: #2563eb; text-decoration: underline;">{model_name}</a>'
        )

    def _build_format_dict(self, columns) -> Dict[str, object]:
        """Build the per-column display formatters for the given columns."""
        format_dict = {}
        for col_name in column_registry.numeric_columns:
            if col_name not in columns:
                continue
            col_def = column_registry.get(col_name)
            # Special handling for Parameters column - use human-readable format
            if col_name == "Parameters":
                format_dict[col_name] = format_parameter_count
            elif col_def and col_def.decimals == 0:
                format_dict[col_name] = '{:.0f}'
            elif col_def and col_def.decimals == 3:
                format_dict[col_name] = '{:.3f}'
            else:
                format_dict[col_name] = '{:.2f}'
        # Format model column as hyperlink without mutating the underlying data
        if "Model" in columns:
            format_dict["Model"] = self._model_link_formatter
        return format_dict

    def apply_styling(self, df: pd.DataFrame) -> "pd.io.formats.style.Styler":
        """
        Apply color styling to DataFrame.

        Works on a copy of ``df``: score columns are coerced to numeric
        (non-numeric entries such as "N/A" become NaN), shaded with the
        gradient relative to each column's own min/max, and display
        formatters are attached for numeric columns and the ``Model`` column.

        Returns a pandas Styler object that Gradio can render.
        """
        if df.empty:
            return df.style

        df_copy = df.copy()

        # errors='coerce' already turns "N/A" (and any other non-numeric
        # text) into NaN, so no separate replace step is needed.
        for col in column_registry.score_columns:
            if col in df_copy.columns:
                df_copy[col] = pd.to_numeric(df_copy[col], errors='coerce')

        # Columns with no numeric data are absent from the ranges and are
        # therefore left unstyled.
        color_ranges = self.calculate_color_ranges(df_copy)

        styler = df_copy.style
        for col, bounds in color_ranges.items():
            # Extra keyword arguments are forwarded by Styler.map to the
            # cell function, which avoids a per-column closure.
            styler = styler.map(
                self._gradient_css,
                subset=[col],
                min_val=bounds['min'],
                max_val=bounds['max'],
            )

        format_dict = self._build_format_dict(df_copy.columns)
        if format_dict:
            # Missing values are rendered as empty cells.
            styler = styler.format(format_dict, na_rep='', escape=None)
        return styler

    def get_datatypes(self, columns: List[str]) -> List[str]:
        """Get Gradio datatypes for columns (delegates to the column registry)."""
        return column_registry.get_datatypes(columns)

    def get_column_widths(self, columns: List[str]) -> List[str]:
        """Get column widths for columns (delegates to the column registry)."""
        return column_registry.get_widths(columns)