""" Leaderboard Styling Module Handles color gradients and visual styling for the leaderboard. """ import logging import html from typing import Dict, Tuple, List import pandas as pd from matplotlib.colors import LinearSegmentedColormap from ..core.columns import column_registry from .transformer import format_parameter_count logger = logging.getLogger(__name__) class LeaderboardStyler: """ Applies visual styling to leaderboard DataFrames. Uses Excel-like Red-Yellow-Green color gradients for score columns. """ # Excel-style color gradient: Red -> Yellow -> Green GRADIENT_COLORS = [ (0.9, 0.1, 0.2), # Red (low scores) (1.0, 1.0, 0.0), # Yellow (medium scores) (0/255, 176/255, 80/255) # Excel Green (high scores) ] def __init__(self): self._colormap = LinearSegmentedColormap.from_list( "ExcelRedYellowGreen", self.GRADIENT_COLORS, N=256 ) @staticmethod def rgb_to_hex(rgb: Tuple[float, float, float]) -> str: """Convert RGB tuple (0-1 range) to hex color.""" r = int(rgb[0] * 255) g = int(rgb[1] * 255) b = int(rgb[2] * 255) return f"#{r:02x}{g:02x}{b:02x}" def get_color_for_value(self, value: float, min_val: float, max_val: float) -> str: """Get hex color for a value within a range.""" if max_val == min_val: normalized = 0.5 else: normalized = (value - min_val) / (max_val - min_val) # Clamp to [0, 0.999] to avoid edge case at exactly 1.0 normalized = max(0, min(0.999, normalized)) rgba = self._colormap(normalized) return self.rgb_to_hex(rgba[:3]) def calculate_color_ranges(self, df: pd.DataFrame) -> Dict[str, Dict[str, float]]: """Calculate min/max for each score column.""" ranges = {} for col_name in column_registry.score_columns: if col_name not in df.columns: continue numeric_values = pd.to_numeric(df[col_name], errors='coerce') if numeric_values.isna().all(): continue ranges[col_name] = { 'min': numeric_values.min(), 'max': numeric_values.max() } return ranges def apply_styling(self, df: pd.DataFrame) -> "pd.io.formats.style.Styler": """ Apply color styling to DataFrame. Returns a pandas Styler object that Gradio can render. """ if df.empty: return df.style df_copy = df.copy() # Convert "N/A" to NaN for proper formatting for col in column_registry.score_columns: if col in df_copy.columns: df_copy[col] = df_copy[col].replace("N/A", pd.NA) df_copy[col] = pd.to_numeric(df_copy[col], errors='coerce') # Calculate color ranges color_ranges = self.calculate_color_ranges(df_copy) # Create style function def apply_gradient(val, col_name: str): if col_name not in color_ranges: return '' if pd.isna(val): return '' try: min_val = color_ranges[col_name]['min'] max_val = color_ranges[col_name]['max'] color_hex = self.get_color_for_value(float(val), min_val, max_val) return f'background-color: {color_hex}; text-align: center; font-weight: bold; color: #333;' except (ValueError, TypeError): return '' # Apply styling styler = df_copy.style for col in column_registry.score_columns: if col in df_copy.columns: styler = styler.map( lambda val, c=col: apply_gradient(val, c), subset=[col] ) # Format numeric columns format_dict = {} for col_name in column_registry.numeric_columns: if col_name in df_copy.columns: col_def = column_registry.get(col_name) # Special handling for Parameters column - use human-readable format if col_name == "Parameters": format_dict[col_name] = format_parameter_count elif col_def and col_def.decimals == 0: format_dict[col_name] = '{:.0f}' elif col_def and col_def.decimals == 3: format_dict[col_name] = '{:.3f}' else: format_dict[col_name] = '{:.2f}' # Format model column as hyperlink without mutating the underlying data if "Model" in df_copy.columns: def _model_link_formatter(value: object) -> str: model_name = html.escape(str(value)) return ( f'{model_name}' ) format_dict["Model"] = _model_link_formatter if format_dict: # Don't replace NA values - let them display as they are in the CSV styler = styler.format(format_dict, na_rep='', escape=None) return styler def get_datatypes(self, columns: List[str]) -> List[str]: """Get Gradio datatypes for columns.""" return column_registry.get_datatypes(columns) def get_column_widths(self, columns: List[str]) -> List[str]: """Get column widths for columns.""" return column_registry.get_widths(columns)