Spaces:
Running
Running
| """ | |
| Visualization module for AMA-Bench leaderboard | |
| Adapted from lmGAME_bench patterns with AMA-specific customizations | |
| """ | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| import pandas as pd | |
| import json | |
| import os | |
| from typing import Dict, List, Optional, Tuple | |
# Benchmark metric names; used as the default metric selection when building
# visualization DataFrames.
METRICS = [
    "Recall",
    "Causal Inference",
    "State Updating",
    "State Abstraction",
]
# The benchmark metrics followed by the aggregate "Average" score.
ALL_METRICS = [*METRICS, "Average"]
def load_model_colors(filepath: str = "assets/model_colors.json") -> Tuple[Dict[str, str], str]:
    """
    Load the color scheme for models and methods from a JSON file.

    Expected file layout (every key optional)::

        {"models": {name: color, ...},
         "methods": {name: color, ...},
         "fallback": color}

    Args:
        filepath: Path to the color configuration JSON.

    Returns:
        ``(colors, fallback)``: a dict mapping model/method names to colors
        (``methods`` entries override ``models`` entries with the same name),
        and the color to use for names not in the dict (``'#808080'`` when the
        file has no ``fallback`` key).

        Loading is best-effort: if the file cannot be read/parsed or is not a
        JSON object, a warning is printed and ``({}, '#808080')`` is returned.
        A section that is present but not a JSON object is skipped with a
        warning; the other section is still used.
    """
    default_fallback = '#808080'
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            color_data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load colors from {filepath}: {e}")
        return {}, default_fallback

    if not isinstance(color_data, dict):
        print(f"Warning: Could not load colors from {filepath}: expected a JSON object")
        return {}, default_fallback

    colors: Dict[str, str] = {}
    # Merge models first, then methods, so a method color wins on a name clash.
    for section in ('models', 'methods'):
        section_colors = color_data.get(section)
        if isinstance(section_colors, dict):
            colors.update(section_colors)
        elif section_colors is not None:
            print(f"Warning: ignoring non-object '{section}' section in {filepath}")

    return colors, color_data.get('fallback', default_fallback)
def normalize_scores(
    values: List[float],
    mean: float,
    std: float,
    *,
    min_std: float = 0.05,
    scale: float = 30.0,
    offset: float = 35.0,
) -> List[float]:
    """
    Normalize scores with a z-score and map them onto a clamped 0-100 scale.
    Adapted from lmGAME_bench's normalize_values() function.

    Args:
        values: Accuracy values (0-1 range). NaN values stay NaN.
        mean: Mean used for the z-score.
        std: Standard deviation used for the z-score. Values below ``min_std``
            (or NaN) are raised to ``min_std``.
        min_std: Floor for ``std``, so a metric where every entry scores about
            the same does not produce extreme z-scores.
        scale: Points added per standard deviation above the mean.
        offset: Score assigned to a value exactly at the mean.

    Returns:
        One normalized score per input value, each in [0, 100], or NaN where
        the input (or ``mean``) is NaN.

    Formula:
        z_score = (value - mean) / max(std, min_std)
        normalized = clamp(z_score * scale + offset, 0, 100)
    """
    # `not >=` instead of `<` so that a NaN std is floored as well.
    if not std >= min_std:
        std = min_std

    normalized = []
    for value in values:
        scaled = (value - mean) / std * scale + offset
        if np.isnan(scaled):
            # Keep missing data missing: max(0, min(100, nan)) evaluates to 100,
            # which would present "no data" as a perfect score.
            normalized.append(float('nan'))
            continue
        normalized.append(max(0.0, min(100.0, scaled)))
    return normalized
def filter_by_category(data: Dict, category: str) -> Dict:
    """
    Keep only the entries that belong to one category.

    Args:
        data: Dataset dict holding an ``'entries'`` list.
        category: "All", "RAG", or "Agent Memory".

    Returns:
        ``data`` itself when category is "All"; otherwise a shallow copy of
        ``data`` whose ``'entries'`` holds only the entries whose ``'category'``
        equals ``category``. The input dict is not modified.
    """
    if category == "All":
        # Nothing to filter; hand back the caller's object untouched.
        return data

    def _in_category(item: Dict) -> bool:
        return item.get('category') == category

    subset = data.copy()
    subset['entries'] = list(filter(_in_category, data['entries']))
    return subset
def _extract_accuracy(scores: Dict, metric: str) -> float:
    """Return ``scores[metric]['accuracy']``, mapping absent data to 0.0 and null to NaN."""
    score_data = scores.get(metric, {})
    if score_data is None:
        # Explicit null for the whole metric: treat as "no data", not zero.
        return np.nan
    accuracy = score_data.get('accuracy', 0.0)
    return np.nan if accuracy is None else accuracy


def prepare_dataframe_for_visualization(
    data: Dict,
    top_n: Optional[int] = None,
    category_filter: str = "All",
    selected_metrics: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Build a DataFrame with both raw and normalized scores.

    Args:
        data: Raw data from model_data.json or method_data.json. Must hold an
            ``'entries'`` list; each entry is read for ``'method'`` (display
            name), ``'scores'`` (metric -> {'accuracy': value}) and, optionally,
            ``'category'``. A missing metric/accuracy counts as 0.0; an explicit
            null counts as missing data (NaN).
        top_n: Number of top entries (by ``Average`` accuracy) to include;
            None or <= 0 keeps all.
        category_filter: "All", "RAG", or "Agent Memory" (for methods only).
        selected_metrics: List of metrics to include (None = METRICS).

    Returns:
        DataFrame sorted by ``Average`` (descending) with a fresh 0..n-1 index
        and columns:
        - Name
        - Category (if any entry has one)
        - {Metric} (raw accuracy 0-1) for each selected metric
        - Average (raw average accuracy)
        - norm_{Metric} (normalized 0-100) for each selected metric
        - Avg Normalized Score (row mean of the norm_ columns, NaN skipped)
        An empty DataFrame is returned when no entries survive the filter.
    """
    if category_filter != "All":
        data = filter_by_category(data, category_filter)

    if not data['entries']:
        return pd.DataFrame()

    if selected_metrics is None:
        selected_metrics = METRICS

    rows = []
    for entry in data['entries']:
        scores = entry['scores']
        row = {'Name': entry['method']}
        if entry.get('category') is not None:
            row['Category'] = entry['category']
        for metric in selected_metrics:
            row[metric] = _extract_accuracy(scores, metric)
        row['Average'] = _extract_accuracy(scores, 'Average')
        rows.append(row)

    df = pd.DataFrame(rows)
    # Stable sort: ties keep input order, so the top_n cut is deterministic.
    df = df.sort_values(by='Average', ascending=False, kind='mergesort')

    # Normalization parameters come from the FULL filtered set (before the
    # top_n cut). Series.mean/std skip NaN, so one missing score doesn't turn
    # the whole metric into NaN; ddof=0 is the population std (numpy default).
    norm_params = {}
    for metric in selected_metrics:
        column = df[metric].astype(float)
        norm_params[metric] = (column.mean(), column.std(ddof=0))

    if top_n is not None and top_n > 0:
        # .copy() so the norm_ columns added below go into an independent
        # frame rather than a slice of the sorted one.
        df = df.head(top_n).copy()

    for metric in selected_metrics:
        mean, std = norm_params[metric]
        df[f'norm_{metric}'] = normalize_scores(df[metric].tolist(), mean, std)

    norm_cols = [f'norm_{metric}' for metric in selected_metrics]
    df['Avg Normalized Score'] = df[norm_cols].mean(axis=1)

    return df.reset_index(drop=True)
def hex_to_rgba(hex_color: str, alpha: float = 0.2) -> str:
    """
    Convert a hex color code to an ``rgba(...)`` color string.

    Accepts ``#RRGGBB`` / ``RRGGBB`` and the shorthand ``#RGB`` / ``RGB``.
    Surrounding whitespace is ignored. Anything after the first six hex digits
    (e.g. the alpha byte of ``#RRGGBBAA``) is ignored; ``alpha`` sets the
    opacity instead.

    Args:
        hex_color: Hex color code (e.g. "#FF0000" or "#F00").
        alpha: Alpha value (0-1).

    Returns:
        RGBA color string, e.g. ``'rgba(255, 0, 0, 0.2)'``.

    Raises:
        ValueError: if ``hex_color`` does not contain a valid hex color.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        # CSS shorthand: "#F0A" means "#FF00AA".
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        # Without this check, 4-5 digit input silently yields a wrong color.
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f'rgba({r}, {g}, {b}, {alpha})'
def create_radar_chart(
    df: pd.DataFrame,
    selected_metrics: List[str],
    title: str = "Performance Across Metrics",
    color_map: Optional[Dict[str, str]] = None
) -> go.Figure:
    """
    Create a radar chart of normalized scores.
    Adapted from lmGAME_bench's create_single_radar_chart().

    Args:
        df: DataFrame from prepare_dataframe_for_visualization(); needs a
            ``Name`` column and a ``norm_{metric}`` column per selected metric.
        selected_metrics: Metric names to use as axes, in axis order.
        title: Chart title.
        color_map: Mapping from names to hex colors; loaded via
            load_model_colors() when None.

    Returns:
        Plotly Figure with the radar chart, or a placeholder figure whose title
        says why nothing was drawn (empty df, no metrics selected, or missing
        normalized columns).

    Features:
        - Each axis = one metric
        - Each trace = one model/method (legend label lowercased)
        - Range: 0-100 (normalized)
        - Interactive legend (click to isolate, double-click to toggle)
    """
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No data available")
        return fig

    # With no axes there is no polygon to draw (closing it below would index
    # an empty list).
    if not selected_metrics:
        fig = go.Figure()
        fig.update_layout(title="No metrics selected")
        return fig

    if color_map is None:
        color_map, fallback_color = load_model_colors()
    else:
        fallback_color = '#808080'

    norm_cols = [f'norm_{metric}' for metric in selected_metrics]
    if not all(col in df.columns for col in norm_cols):
        fig = go.Figure()
        fig.update_layout(title="Missing normalized data")
        return fig

    # Repeat the first axis at the end to close the polygon.
    theta = list(selected_metrics) + [selected_metrics[0]]

    fig = go.Figure()
    for _, row in df.iterrows():
        name = row['Name']
        r = [row[col] for col in norm_cols]

        color = color_map.get(name, fallback_color)
        try:
            fillcolor = hex_to_rgba(color, 0.2)
        except (ValueError, AttributeError):
            # Not a hex string (e.g. a CSS color name in the color file): let
            # Plotly derive the fill from the line color instead of crashing.
            fillcolor = None

        fig.add_trace(go.Scatterpolar(
            r=r + [r[0]],
            theta=theta,
            mode='lines+markers',
            fill='toself',
            name=name.lower(),  # Lowercase for legend
            line=dict(color=color, width=2),
            marker=dict(color=color, size=6),
            fillcolor=fillcolor,
            opacity=0.7,
            hovertemplate='<b>%{fullData.name}</b><br>%{theta}: %{r:.1f}<extra></extra>'
        ))

    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            xanchor='center',
            font=dict(size=18)
        ),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=11),
                gridcolor='lightgray',
                gridwidth=1
            ),
            angularaxis=dict(
                tickfont=dict(size=12, weight='bold')
            )
        ),
        legend=dict(
            font=dict(size=11),
            title=dict(text="Models/Methods 💡", font=dict(size=12)),
            itemsizing='trace',
            x=1.05,
            y=1,
            xanchor='left',
            yanchor='top',
            bgcolor='rgba(255,255,255,0.6)',
            bordercolor='gray',
            borderwidth=1,
            itemclick="toggleothers",
            itemdoubleclick="toggle"
        ),
        height=550,
        margin=dict(l=80, r=200, t=80, b=80)
    )

    return fig
def create_group_bar_chart(
    df: pd.DataFrame,
    selected_metrics: List[str],
    top_n: int = 5,
    color_map: Optional[Dict[str, str]] = None
) -> go.Figure:
    """
    Create a bar chart showing the top N performers for each metric.
    Adapted from lmGAME_bench's create_group_bar_chart().

    Args:
        df: DataFrame with a ``Name`` column and ``norm_{metric}`` columns
            (from prepare_dataframe_for_visualization()).
        selected_metrics: List of metrics to display.
        top_n: Number of top performers to show per metric.
        color_map: Mapping from names to colors; loaded via
            load_model_colors() when None.

    Returns:
        Plotly Figure with the bar chart, or a placeholder figure when df is
        empty or a required normalized column is missing.

    Structure:
        - X-axis: one category per (metric, rank), e.g. "Recall<br>#1"
        - Y-axis: normalized score (0-100)
        - One trace (legend entry) per model/method, with a bar at each
          category where it ranks in the top N for that metric
    """
    if df.empty:
        fig = go.Figure()
        fig.update_layout(title="No data available")
        return fig

    if color_map is None:
        color_map, fallback_color = load_model_colors()
    else:
        fallback_color = '#808080'

    norm_cols = [f'norm_{metric}' for metric in selected_metrics]
    if not all(col in df.columns for col in norm_cols):
        fig = go.Figure()
        fig.update_layout(title="Missing normalized data")
        return fig

    # Ordered x categories, plus each name's (x categories, scores) gathered
    # in one pass instead of rescanning every ranking per name.
    all_x_categories: List[str] = []
    bars_by_name: Dict[str, Tuple[List[str], List[float]]] = {}

    for metric in selected_metrics:
        norm_col = f'norm_{metric}'
        top_rows = (
            df[df[norm_col].notna()]
            .sort_values(by=norm_col, ascending=False)
            .head(top_n)
        )
        ranked = zip(top_rows['Name'], top_rows[norm_col])
        for rank, (name, score) in enumerate(ranked, start=1):
            x_category = f"{metric}<br>#{rank}"
            all_x_categories.append(x_category)
            x_vals, y_vals = bars_by_name.setdefault(name, ([], []))
            x_vals.append(x_category)
            y_vals.append(score)

    fig = go.Figure()
    for name in sorted(bars_by_name):
        x_vals, y_vals = bars_by_name[name]
        fig.add_trace(go.Bar(
            name=name,
            x=x_vals,
            y=y_vals,
            marker_color=color_map.get(name, fallback_color),
            hovertemplate="<b>%{fullData.name}</b><br>Score: %{y:.1f}<extra></extra>"
        ))

    fig.update_layout(
        title=dict(
            text=f"Top {top_n} Performers by Metric",
            x=0.5,
            xanchor='center',
            font=dict(size=18)
        ),
        xaxis_title="Metrics (Ranked by Performance)",
        yaxis_title="Normalized Score",
        xaxis=dict(
            categoryorder='array',
            categoryarray=all_x_categories,
            tickangle=0
        ),
        yaxis=dict(range=[0, 100]),
        # Each x category holds exactly one bar (the entry at that rank), so
        # overlay draws it centred at full width. 'group' would reserve a
        # side-by-side slot for every trace at every category, leaving thin,
        # off-centre bars.
        barmode='overlay',
        bargap=0.15,
        height=550,
        margin=dict(l=60, r=200, t=80, b=80),
        legend=dict(
            font=dict(size=11),
            title=dict(text="Models/Methods 💡", font=dict(size=12)),
            itemsizing='trace',
            x=1.05,
            y=1,
            xanchor='left',
            yanchor='top',
            bgcolor='rgba(255,255,255,0.6)',
            bordercolor='gray',
            borderwidth=1
        )
    )

    return fig
def create_horizontal_bar_chart(
    df: pd.DataFrame,
    metric: str,
    color_map: Optional[Dict[str, str]] = None
) -> go.Figure:
    """
    Build a horizontal bar chart ranking every entry on a single metric.
    Adapted from lmGAME_bench's create_horizontal_bar_chart().

    Args:
        df: DataFrame with a ``Name`` column and a raw-score column ``metric``.
        metric: Name of the metric column to plot (e.g. "Recall").
        color_map: Mapping from names to colors; loaded via
            load_model_colors() when None.

    Returns:
        Plotly Figure with one bar per entry that has a value for ``metric``,
        or a placeholder figure when there is nothing to plot.

    Features:
        - Y-axis: model/method names, highest score at the top
        - X-axis: raw accuracy on a fixed 0-1 range (not normalized)
    """
    def _placeholder(message: str) -> go.Figure:
        empty_fig = go.Figure()
        empty_fig.update_layout(title=message)
        return empty_fig

    if df.empty or metric not in df.columns:
        return _placeholder(f"No data available for {metric}")

    if color_map is not None:
        fallback_color = '#808080'
    else:
        color_map, fallback_color = load_model_colors()

    # Plotly draws the first y category at the bottom, so an ascending sort
    # puts the best score at the top of the chart.
    ranked = df[df[metric].notna()].copy().sort_values(by=metric, ascending=True)
    if ranked.empty:
        return _placeholder(f"No valid data for {metric}")

    bar_colors = [color_map.get(label, fallback_color) for label in ranked['Name']]
    bars = go.Bar(
        y=ranked['Name'],
        x=ranked[metric],
        orientation='h',
        marker=dict(color=bar_colors, line=dict(color='#2c3e50', width=1)),
        hovertemplate='%{y}<br>Accuracy: %{x:.4f}<extra></extra>'
    )
    fig = go.Figure(bars)

    title_spec = dict(
        text=f'{metric} - Detailed Rankings',
        x=0.5,
        xanchor='center',
        font=dict(size=18)
    )
    fig.update_layout(
        title=title_spec,
        xaxis_title="Accuracy",
        yaxis_title="Model/Method",
        xaxis=dict(range=[0, 1], gridcolor='#e0e0e0'),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        font=dict(color='#2c3e50'),
        height=max(400, len(ranked) * 30),  # ~30px per bar, at least 400px
        margin=dict(l=200, r=40, t=80, b=60),
        showlegend=False
    )
    return fig
def create_multi_metric_bar_chart(
    df: pd.DataFrame,
    selected_metrics: List[str],
    color_map: Optional[Dict[str, str]] = None
) -> go.Figure:
    """
    Create a grouped horizontal bar chart showing several metrics per model/method.

    Args:
        df: DataFrame with a ``Name`` column and one raw-score column per metric.
        selected_metrics: Metric columns to display (e.g. ["Recall", "Causal Inference"]).
        color_map: Accepted for signature parity with the other chart builders;
            currently unused (bar colors come from a single-hue gradient).

    Returns:
        Plotly Figure with the grouped bar chart, or a placeholder figure when
        df is empty, no metrics are selected, a metric column is missing, or no
        row has a value for any selected metric.

    Features:
        - Y-axis: model/method names, best average at the top
        - X-axis: raw accuracy score (0-1 range)
        - One bar per selected metric for each entry; metrics are told apart
          by fill pattern (NOTE: marker.pattern needs Plotly >= 5.0)
        - Bar shade deepens with the entry's average over the selected metrics
    """
    if df.empty or not selected_metrics:
        fig = go.Figure()
        fig.update_layout(title="No data available")
        return fig

    missing_metrics = [m for m in selected_metrics if m not in df.columns]
    if missing_metrics:
        fig = go.Figure()
        fig.update_layout(title=f"Missing metrics: {', '.join(missing_metrics)}")
        return fig

    # Keep entries with at least one selected metric. The copy lets us add the
    # helper column below without writing into a slice of the caller's frame.
    metric_df = df[df[selected_metrics].notna().any(axis=1)].copy()
    if metric_df.empty:
        fig = go.Figure()
        fig.update_layout(title="No valid data for selected metrics")
        return fig

    metric_df['avg_score'] = metric_df[selected_metrics].mean(axis=1)
    # Plotly draws the first y category at the bottom, so ascending order puts
    # the best average at the top of the chart.
    metric_df = metric_df.sort_values(by='avg_score', ascending=True)

    base_color = "#636EFA"
    base_hex = base_color.lstrip('#')
    base_rgb = tuple(int(base_hex[i:i + 2], 16) for i in (0, 2, 4))

    min_score = metric_df['avg_score'].min()
    max_score = metric_df['avg_score'].max()
    score_range = max_score - min_score

    def gradient_color(score: float) -> str:
        """Blend base_color toward white by score: higher score = deeper color."""
        # Map the score to [0, 1]; if every entry ties, use the midpoint.
        normalized = (score - min_score) / score_range if score_range > 0 else 0.5
        intensity = 0.3 + normalized * 0.7  # 0.3 (light) .. 1.0 (full base color)
        r, g, b = (int(255 - (255 - channel) * intensity) for channel in base_rgb)
        return f'rgba({r}, {g}, {b}, 0.5)'  # 50% transparency

    # The colors depend only on each entry's avg_score, so compute them once
    # and share them across every metric trace.
    bar_colors = [gradient_color(score) for score in metric_df['avg_score']]

    # Because every trace shares bar_colors, a per-metric fill pattern is what
    # distinguishes the metrics (and their legend entries).
    pattern_shapes = ['', '/', '.', 'x', '\\', '+', '-', '|']

    fig = go.Figure()
    for index, metric in enumerate(selected_metrics):
        fig.add_trace(go.Bar(
            name=metric,
            y=metric_df['Name'],
            x=metric_df[metric],
            orientation='h',
            marker=dict(
                color=bar_colors,
                line=dict(color='#2c3e50', width=0.5),
                pattern=dict(shape=pattern_shapes[index % len(pattern_shapes)])
            ),
            hovertemplate=f'<b>%{{y}}</b><br>{metric}: %{{x:.4f}}<extra></extra>'
        ))

    fig.update_layout(
        title=dict(
            text=f'Detailed Comparison - {", ".join(selected_metrics)}',
            x=0.5,
            xanchor='center',
            font=dict(size=18)
        ),
        xaxis_title="Accuracy",
        yaxis_title="Model/Method",
        xaxis=dict(
            range=[0, 1],
            gridcolor='#e0e0e0'
        ),
        barmode='group',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        font=dict(color='#2c3e50'),
        height=max(500, len(metric_df) * 40),  # ~40px per entry, at least 500px
        margin=dict(l=200, r=40, t=80, b=80),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5,
            font=dict(size=12)
        )
    )

    return fig