Spaces:
Running
Running
| """ | |
| PazaBench Visualization Functions for Gradio Integration | |
| """ | |
| import os | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from src.constants import ( | |
| COUNTRY_NAMES, | |
| LANGUAGE_COUNTRY_MAP, | |
| LANGUAGE_SAMPLE_COUNTS, | |
| ) | |
| from src.model_counts import MODEL_PARAMETER_COUNTS | |
# Load model family colors from CSV
def _load_model_family_colors() -> tuple[dict[str, str], dict[str, str]]:
    """
    Load color mappings from ``display/model_family_colors.csv`` next to this module.

    The CSV is read with pandas and its ``Model Families``, ``Model ID`` and
    ``Color`` columns are used. Rows with a missing value in any of those
    columns are skipped, so one incomplete row does not discard the rest of
    the file.

    Returns:
        - model_family_colors: dict mapping normalized model family -> color
          (the first color encountered for that family)
        - model_id_colors: dict mapping model_id -> color (a later row for
          the same ID overrides an earlier one)

    If the file cannot be read (or a required column is missing) a warning
    is printed and the dicts collected so far (possibly empty) are returned,
    so the app still starts and falls back to default colors.
    """
    csv_path = os.path.join(os.path.dirname(__file__), 'display', 'model_family_colors.csv')
    model_family_colors: dict[str, str] = {}
    model_id_colors: dict[str, str] = {}
    try:
        df = pd.read_csv(csv_path)
        for family, model_id, color in zip(df['Model Families'], df['Model ID'], df['Color']):
            # Blank CSV cells are read as NaN (a float); skip such rows instead of
            # letting .lower() raise and abort loading of every remaining row.
            if pd.isna(family) or pd.isna(model_id) or pd.isna(color):
                continue
            family = str(family).strip()
            model_id = str(model_id).strip()
            color = str(color).strip()
            # Keep only the first color encountered for each (normalized) family
            model_family_colors.setdefault(_normalize_family_name(family), color)
            # Store color for each model ID
            model_id_colors[model_id] = color
    except Exception as e:  # best-effort load: never block app start-up on colors
        print(f"Warning: Could not load model family colors: {e}")
    return model_family_colors, model_id_colors
def _normalize_family_name(family: str) -> str:
    """
    Return a canonical lookup key for a model family name.

    The name is lowercased and every underscore, dash and space is dropped,
    so spellings such as ``"Seamless_M4T"`` and ``"seamless-m4t"`` share a key.
    """
    dropped = "_- "
    return "".join(ch for ch in family.lower() if ch not in dropped)
# Module-level color lookup tables, built once at import time from the CSV.
# If loading fails they may be empty (or only partially filled); lookups that
# miss them fall back to the neutral gray '#888888'.
MODEL_FAMILY_COLORS, MODEL_ID_COLORS = _load_model_family_colors()
def _get_color_for_family(family: str) -> str:
    """
    Look up the display color of a model family.

    The name is normalized with ``_normalize_family_name`` before the lookup;
    families missing from ``MODEL_FAMILY_COLORS`` get the gray ``'#888888'``.
    """
    key = _normalize_family_name(family)
    try:
        return MODEL_FAMILY_COLORS[key]
    except KeyError:
        return '#888888'
def _get_color_for_model(model_id: str) -> str:
    """
    Get the display color for a specific model ID.

    Lookup order:
      1. an exact match in ``MODEL_ID_COLORS``;
      2. the color of a known family whose (normalized) key occurs in the
         normalized model ID -- the longest such key wins, so a more specific
         family beats a shorter one that is merely a substring of it;
      3. the neutral gray ``'#888888'``.
    """
    if model_id in MODEL_ID_COLORS:
        return MODEL_ID_COLORS[model_id]
    # MODEL_FAMILY_COLORS keys are normalized (lowercase, no '_', '-' or ' '),
    # so the ID has to be normalized the same way; against the raw ID a key
    # like "seamlessm4t" could never match an ID containing "seamless-m4t".
    normalized_id = _normalize_family_name(model_id)
    # Skip empty keys: "" is a substring of every ID and would match anything.
    matches = [family for family in MODEL_FAMILY_COLORS if family and family in normalized_id]
    if matches:
        return MODEL_FAMILY_COLORS[max(matches, key=len)]
    return '#888888'
def _remove_wer_outliers(df: pd.DataFrame, multiplier: float = 1.5, column: str = 'wer') -> pd.DataFrame:
    """
    Remove high error-rate outliers using the IQR method for cleaner visualizations.

    Only HIGH outliers (poor performers) are removed; LOW outliers (best
    performers) are kept, because a lower error rate is better.

    Args:
        df: DataFrame containing the metric column.
        multiplier: IQR multiplier (default 1.5 for standard outlier detection).
        column: Metric column to filter on. Defaults to ``'wer'``; pass e.g.
            ``'cer'`` to apply the same rule to character error rate.

    Returns:
        ``df`` restricted to rows whose ``column`` value is at most
        ``Q3 + multiplier * IQR``. Rows where ``column`` is NaN are dropped
        too, since NaN never compares ``<=`` the bound. ``df`` itself is
        returned unchanged when it is empty or lacks ``column``.
    """
    if df.empty or column not in df.columns:
        return df
    values = df[column]
    q1 = values.quantile(0.25)
    q3 = values.quantile(0.75)
    # Only an upper bound: lower error is better, so low values are never outliers
    upper_bound = q3 + multiplier * (q3 - q1)
    return df[values <= upper_bound]
def create_language_coverage_chart(selected_languages: list[str] | None = None) -> go.Figure:
    """
    Create a horizontal bar chart of the sample count of each language in PazaBench.

    Bars are ordered so the language with the most samples is at the top.
    When ``selected_languages`` is non-empty, those bars are drawn in a darker
    blue with a border and all other bars in a pale blue; otherwise every bar
    uses the same theme color.

    Args:
        selected_languages: Languages to highlight (None or empty = no highlighting).

    Returns:
        A Plotly figure.
    """
    # Explicit columns keep sort_values working even if there are no languages
    df = pd.DataFrame(
        list(LANGUAGE_SAMPLE_COUNTS.items()),
        columns=["Language", "Sample Count"],
    )
    # Plotly draws horizontal bars bottom-to-top, so ascending order puts the
    # language with the most samples at the top of the chart.
    df = df.sort_values("Sample Count", ascending=True)

    selected = set(selected_languages) if selected_languages else set()
    if selected:
        is_selected = [lang in selected for lang in df["Language"]]
        colors = ["#0078D4" if sel else "#D0E8FF" for sel in is_selected]
        # Border only around the selected bars
        marker_line = dict(
            width=[2 if sel else 0 for sel in is_selected],
            color=["#005A9E" if sel else "rgba(0,0,0,0)" for sel in is_selected],
        )
        title_text = f"Language Coverage in PazaBench ({len(selected_languages)} selected)"
    else:
        colors = "#8CD0FF"  # Solid blue color matching theme
        marker_line = None
        title_text = "Language Coverage in PazaBench"

    fig = go.Figure(go.Bar(
        y=df["Language"],
        x=df["Sample Count"],
        orientation='h',
        marker=dict(color=colors, line=marker_line),
        text=None,  # No text labels on the bars
        textposition='none',
        hovertemplate='<b>%{y}</b><br>Samples: %{x:,}<extra></extra>',
    ))
    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=16),
            x=0.5,
        ),
        xaxis_title="Number of Samples",
        yaxis_title="",
        height=800,
        autosize=True,
        margin=dict(l=120, r=50, t=60, b=40),
        template='plotly_white',
        showlegend=False,
    )
    return fig
def create_language_location_map(languages: str | list[str] | None = None) -> go.Figure:
    """
    Create an interactive choropleth map of Africa showing where the given language(s) are spoken.

    - If at least one requested language appears in LANGUAGE_COUNTRY_MAP, every
      country where any of them is spoken is shaded. The hover text names the
      language (single selection) or lists the selected languages spoken in
      that country (multiple selections).
    - If languages are given but none appears in LANGUAGE_COUNTRY_MAP, an empty
      map titled "Select a Language to Explore" is returned.
    - With no languages, every country that has PazaBench data is shown in a
      very light blue as an overview.

    Args:
        languages: A single language name, a list of names, or None.
            Repeated names are ignored after their first occurrence.

    Returns:
        A Plotly figure.
    """
    fig = go.Figure()

    # Normalize to a de-duplicated list so a repeated selection is neither
    # listed twice in the hover text nor double-counted in the title.
    if isinstance(languages, str):
        languages = [languages]
    languages = list(dict.fromkeys(languages)) if languages else []

    if languages:
        # country code -> selected languages spoken there; a dict keeps a
        # deterministic row order (selection order, then map order).
        country_language_map: dict[str, list[str]] = {}
        for lang in languages:
            for code in LANGUAGE_COUNTRY_MAP.get(lang, []):
                country_language_map.setdefault(code, []).append(lang)

        if country_language_map:
            codes = list(country_language_map)
            df_map = pd.DataFrame({
                "country_code": codes,
                "country_name": [COUNTRY_NAMES.get(code, code) for code in codes],
                "has_language": [1] * len(codes),
                "languages": [", ".join(country_language_map[code]) for code in codes],
            })
            if len(languages) == 1:
                hover_template = "<b>%{text}</b><br>" + f"{languages[0]} is spoken here<extra></extra>"
                title_text = f"Where {languages[0]} is Spoken"
            else:
                hover_template = "<b>%{text}</b><br>Languages: %{customdata}<extra></extra>"
                title_text = f"Where {len(languages)} Selected Languages are Spoken"
            fig.add_trace(go.Choropleth(
                locations=df_map["country_code"],
                z=df_map["has_language"],
                text=df_map["country_name"],
                customdata=df_map["languages"],
                hovertemplate=hover_template,
                colorscale=[[0, "#8CD0FF"], [1, "#8CD0FF"]],  # Solid blue color
                showscale=False,
                marker_line_color="white",
                marker_line_width=0.5,
            ))
        else:
            # None of the requested languages has known country locations
            title_text = "Select a Language to Explore"
    else:
        # Default view: show all countries with PazaBench data lightly highlighted
        codes = list(dict.fromkeys(
            code for countries in LANGUAGE_COUNTRY_MAP.values() for code in countries
        ))
        df_map = pd.DataFrame({
            "country_code": codes,
            "country_name": [COUNTRY_NAMES.get(code, code) for code in codes],
            "in_pazabench": [1] * len(codes),
        })
        fig.add_trace(go.Choropleth(
            locations=df_map["country_code"],
            z=df_map["in_pazabench"],
            text=df_map["country_name"],
            hovertemplate="<b>%{text}</b><br>Has PazaBench data<extra></extra>",
            colorscale=[[0, "#E8F4FF"], [1, "#E8F4FF"]],  # Very light blue matching theme
            showscale=False,
            marker_line_color="#B3D9FF",
            marker_line_width=0.5,
        ))
        title_text = "Select a Language to Explore"

    fig.update_geos(
        visible=True,
        resolution=50,
        scope="africa",
        showcountries=True,
        countrycolor="lightgray",
        showcoastlines=True,
        coastlinecolor="gray",
        showland=True,
        landcolor="#f5f5f5",
        showocean=True,
        oceancolor="#e3f2fd",
        showlakes=True,
        lakecolor="#e3f2fd",
        projection_type="natural earth",
        center=dict(lat=5, lon=20),
    )
    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=18),
            x=0.5,
        ),
        height=600,
        autosize=True,
        margin=dict(l=5, r=5, t=50, b=5),
        geo=dict(
            bgcolor="rgba(0,0,0,0)",
        ),
        template="plotly_white",
    )
    return fig
def get_language_sample_info(languages: str | list[str] | None = None, asr_df: pd.DataFrame | None = None) -> str:
    """
    Build a styled HTML summary card for the selected language(s).

    For languages present in LANGUAGE_SAMPLE_COUNTS, the card shows their
    combined sample count, the countries where they are spoken, and the
    dataset groups found for them in ``asr_df``. If no language is given or
    none is recognized, an overview card for the whole benchmark is returned.

    Args:
        languages: A single language name, a list of names, or None.
            Repeated names are counted once.
        asr_df: Optional ASR results; its ``language`` and ``dataset_group``
            columns are used to list dataset sources.

    Returns:
        An HTML string for display in Gradio.
    """
    from html import escape  # local import: avoids clashing with module-level names

    if isinstance(languages, str):
        languages = [languages]

    # De-duplicate (first occurrence wins) so a repeated selection does not
    # double-count samples, then keep only languages PazaBench knows about.
    valid_languages = [
        lang for lang in dict.fromkeys(languages or [])
        if lang in LANGUAGE_SAMPLE_COUNTS
    ]

    if valid_languages:
        total_samples = sum(LANGUAGE_SAMPLE_COUNTS.get(lang, 0) for lang in valid_languages)
        all_countries = {
            code for lang in valid_languages for code in LANGUAGE_COUNTRY_MAP.get(lang, [])
        }
        country_names = sorted(COUNTRY_NAMES.get(code, code) for code in all_countries)

        # Dataset groups from the ASR results, if available
        dataset_groups: list[str] = []
        if (
            asr_df is not None
            and not asr_df.empty
            and 'language' in asr_df.columns
            and 'dataset_group' in asr_df.columns
        ):
            groups = asr_df.loc[asr_df['language'].isin(valid_languages), 'dataset_group']
            # Missing groups are NaN floats; drop them so sorting strings cannot fail
            dataset_groups = sorted({str(group) for group in groups.dropna()})

        # Values come from data files, so escape them before embedding in HTML
        title = f"🌍 {escape(', '.join(sorted(valid_languages)))}"
        countries_text = escape(', '.join(country_names)) if country_names else 'N/A'
        sources_text = (
            escape(', '.join(dataset_groups)) if dataset_groups
            else '<em>No dataset info available</em>'
        )
        return f"""
<div style="background: linear-gradient(135deg, #DDF1FF 0%, #ffffff 100%); border-radius: 12px; padding: 20px; border: 1px solid #8CD0FF;">
    <h4 style="margin: 0 0 16px 0; color: #0f172a; font-size: 1.1em; font-weight: 600; border-bottom: 2px solid #8CD0FF; padding-bottom: 8px;">
        {title}
    </h4>
    <div style="display: grid; gap: 12px;">
        <div style="background: white; padding: 12px 16px; border-radius: 8px; display: flex; justify-content: space-between; align-items: center;">
            <span style="color: #666; font-weight: 500; font-size: 0.9rem;">📊 Total Samples</span>
            <span style="color: #0f172a; font-weight: 700; font-size: 0.9rem;">{total_samples:,}</span>
        </div>
        <div style="background: white; padding: 12px 16px; border-radius: 8px;">
            <div style="color: #666; font-weight: 500; margin-bottom: 6px; font-size: 0.9rem;">📍 Countries ({len(country_names)})</div>
            <div style="color: #0f172a; font-weight: 500; font-size: 0.9rem;">{countries_text}</div>
        </div>
        <div style="background: white; padding: 12px 16px; border-radius: 8px;">
            <div style="color: #666; font-weight: 500; margin-bottom: 6px; font-size: 0.9rem;">📁 Dataset Sources ({len(dataset_groups)})</div>
            <div style="color: #0f172a; font-size: 0.9rem;">
                {sources_text}
            </div>
        </div>
    </div>
</div>
"""

    # Default view - show sample overview summary
    total_languages = len(LANGUAGE_SAMPLE_COUNTS)
    total_samples = sum(LANGUAGE_SAMPLE_COUNTS.values())
    total_countries = len({code for codes in LANGUAGE_COUNTRY_MAP.values() for code in codes})
    return f"""
<div style="background: linear-gradient(135deg, #f0f9ff 0%, #ffffff 100%); border-radius: 12px; padding: 20px; border: 1px solid #e0e7ef;">
    <h4 style="margin: 0 0 16px 0; color: #0f172a; font-size: 1.1em; font-weight: 600; border-bottom: 2px solid #8CD0FF; padding-bottom: 8px;">
        📊 Sample Overview
    </h4>
    <div style="display: grid; gap: 10px; margin-bottom: 16px;">
        <div style="background: white; padding: 10px 14px; border-radius: 8px; display: flex; justify-content: space-between; align-items: center; border: 1px solid #e5e7eb;">
            <span style="color: #666; font-weight: 500; font-size: 0.9rem;">🌍 Languages</span>
            <span style="color: #0f172a; font-weight: 700; font-size: 0.9rem;">{total_languages}</span>
        </div>
        <div style="background: white; padding: 10px 14px; border-radius: 8px; display: flex; justify-content: space-between; align-items: center; border: 1px solid #e5e7eb;">
            <span style="color: #666; font-weight: 500; font-size: 0.9rem;">📊 Total Samples</span>
            <span style="color: #0f172a; font-weight: 700; font-size: 0.9rem;">{total_samples:,}</span>
        </div>
        <div style="background: white; padding: 10px 14px; border-radius: 8px; display: flex; justify-content: space-between; align-items: center; border: 1px solid #e5e7eb;">
            <span style="color: #666; font-weight: 500; font-size: 0.9rem;">📍 Countries</span>
            <span style="color: #0f172a; font-weight: 700; font-size: 0.9rem;">{total_countries}</span>
        </div>
    </div>
    <p style="margin: 14px 0 0 0; font-size: 0.85rem; color: #666; font-style: italic; text-align: center;">
        👈 Select a language on the left to explore its details
    </p>
</div>
"""
def get_language_sample_info_df(language: str | None = None) -> pd.DataFrame:
    """
    Legacy helper kept for backward compatibility: summarize one language as a table.

    Returns a two-column (``Metric`` / ``Value``) DataFrame holding the
    language name, its comma-formatted sample count and the countries where
    it is spoken. A missing or unknown language yields placeholder values.
    """
    metrics = ["Language", "Total Samples", "Countries"]
    if not language or language not in LANGUAGE_SAMPLE_COUNTS:
        return pd.DataFrame({"Metric": metrics, "Value": ["Select a language", "-", "-"]})

    names = [COUNTRY_NAMES.get(code, code) for code in LANGUAGE_COUNTRY_MAP.get(language, [])]
    values = [
        language,
        f"{LANGUAGE_SAMPLE_COUNTS[language]:,}",
        ", ".join(names) if names else "N/A",
    ]
    return pd.DataFrame({"Metric": metrics, "Value": values})
def get_all_languages() -> list[str]:
    """Return every PazaBench language name, sorted in ascending string order."""
    names = list(LANGUAGE_SAMPLE_COUNTS)
    names.sort()
    return names
def create_africa_language_map() -> go.Figure:
    """
    Create an interactive choropleth map of Africa showing PazaBench language coverage.

    Every country that appears in LANGUAGE_COUNTRY_MAP is shaded by how many
    languages are mapped to it; hovering shows that count and the sorted list
    of those languages.
    """
    # country code -> languages spoken there (in LANGUAGE_COUNTRY_MAP order)
    spoken_in: dict = {}
    for language, countries in LANGUAGE_COUNTRY_MAP.items():
        for country_code in countries:
            spoken_in.setdefault(country_code, []).append(language)

    df_map = pd.DataFrame([
        {
            "country_code": code,
            "country_name": COUNTRY_NAMES.get(code, code),
            "language_count": len(langs),
            "languages": ", ".join(sorted(langs)),
        }
        for code, langs in spoken_in.items()
    ])

    blues = [
        [0, "#DDF1FF"],
        [0.25, "#B5E0FF"],
        [0.5, "#8CD0FF"],
        [0.75, "#6BC0F5"],
        [1, "#4AAFEB"],
    ]
    hover = "<b>%{text}</b><br>Languages: %{customdata[1]}<br><i>%{customdata[0]}</i><extra></extra>"
    fig = go.Figure(go.Choropleth(
        locations=df_map["country_code"],
        z=df_map["language_count"],
        text=df_map["country_name"],
        customdata=df_map[["languages", "language_count"]],
        hovertemplate=hover,
        colorscale=blues,
        showscale=True,
        colorbar=dict(title="Languages", tickmode="linear", tick0=1, dtick=1),
        marker_line_color="white",
        marker_line_width=0.5,
    ))

    fig.update_geos(
        visible=True,
        resolution=50,
        scope="africa",
        showcountries=True,
        countrycolor="lightgray",
        showcoastlines=True,
        coastlinecolor="gray",
        showland=True,
        landcolor="#f5f5f5",
        showocean=True,
        oceancolor="#e3f2fd",
        showlakes=True,
        lakecolor="#e3f2fd",
        projection_type="natural earth",
        center=dict(lat=5, lon=20),
    )
    fig.update_layout(
        title=dict(text="African Languages in PazaBench", font=dict(size=18), x=0.5),
        height=600,
        autosize=True,
        margin=dict(l=5, r=5, t=50, b=5),
        geo=dict(bgcolor="rgba(0,0,0,0)"),
        template="plotly_white",
    )
    return fig
def create_model_leaderboard(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int = 15) -> go.Figure:
    """
    Visualization 1: WER leaderboard of model families or individual models.

    - No language filter: one bar per model family showing its median WER,
      with the WER standard deviation as an error bar, ordered by median WER.
    - One or more languages: the ``top_n_models`` individual models with the
      lowest median WER on those languages, each bar colored by its family.

    High-WER outliers are removed with ``_remove_wer_outliers`` before
    aggregating. Sample totals take the first ``num_samples`` value of each
    (language, dataset_group) pair per bar, so a pair is counted once.

    Args:
        df: Evaluation results; the ``language``, ``model_family``, ``model``,
            ``dataset_group``, ``wer``, ``cer``, ``rtfx`` and ``num_samples``
            columns are read.
        languages: Languages to filter by (None or empty = all languages).
        top_n_models: How many individual models to show when filtering by
            language (default: 15).

    Returns:
        A Plotly horizontal bar chart.
    """
    subset = df[df['language'].isin(languages)] if languages else df.copy()
    subset = _remove_wer_outliers(subset)

    if languages:
        # Individual-model view: best top_n_models by median WER
        keys = ['model_family', 'model']
        stats = subset.groupby(keys).agg(
            wer_median=('wer', 'median'),
            wer_std=('wer', 'std'),
            count=('wer', 'count'),
            cer_median=('cer', 'median'),
            rtfx_median=('rtfx', 'median'),
        ).reset_index()
        per_set = subset.groupby(keys + ['language', 'dataset_group'])['num_samples'].first().reset_index()
        totals = per_set.groupby(keys)['num_samples'].sum().rename('total_samples').reset_index()
        stats = stats.merge(totals, on=keys, how='left')
        stats = stats.sort_values('wer_median').head(top_n_models)
        # Label bars with the part after the last '/' of the model name
        stats['model_short'] = stats['model'].apply(lambda name: name.rsplit('/', 1)[-1])
        stats['color'] = stats['model_family'].apply(_get_color_for_family)

        hover = (
            '<b>%{customdata[0]}</b><br>'
            '<i>Family: %{customdata[1]}</i><br><br>'
            '<b>Median WER:</b> %{x:.3f}<br>'
            '<b>RTFx:</b> %{customdata[2]:.1f}<br>'
            '<b>Evaluations:</b> %{customdata[3]}<br>'
            '<b>Samples:</b> %{customdata[4]:,}<extra></extra>'
        )
        fig = go.Figure(go.Bar(
            y=stats['model_short'],
            x=stats['wer_median'],
            orientation='h',
            marker=dict(color=stats['color']),
            text=stats['wer_median'].round(2),
            textposition='outside',
            hovertemplate=hover,
            customdata=stats[['model', 'model_family', 'rtfx_median', 'count', 'total_samples']],
        ))

        shown = ", ".join(languages[:3])
        if len(languages) > 3:
            shown += "..."
        title_text = f"Top {min(top_n_models, len(stats))} Models for {shown}"
    else:
        # Family view: one aggregated bar per model family
        stats = subset.groupby('model_family').agg(
            wer_median=('wer', 'median'),
            wer_std=('wer', 'std'),
            count=('wer', 'count'),
            cer_median=('cer', 'median'),
            rtfx_median=('rtfx', 'median'),
        ).reset_index()
        # Count each (family, language, dataset_group) sample set once, however
        # many models of the family were evaluated on it.
        per_set = subset.groupby(['model_family', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        totals = per_set.groupby('model_family')['num_samples'].sum().rename('total_samples').reset_index()
        stats = stats.merge(totals, on='model_family', how='left').sort_values('wer_median')
        stats['color'] = stats['model_family'].apply(_get_color_for_family)

        hover = (
            '<b>%{y}</b><br><br>'
            '<b>Median WER:</b> %{x:.3f}<br>'
            '<b>Std Dev:</b> %{customdata[0]:.3f}<br>'
            '<b>RTFx:</b> %{customdata[1]:.1f}<br>'
            '<b>Evaluations:</b> %{customdata[2]}<br>'
            '<b>Samples:</b> %{customdata[3]:,}<extra></extra>'
        )
        fig = go.Figure(go.Bar(
            y=stats['model_family'],
            x=stats['wer_median'],
            orientation='h',
            error_x=dict(type='data', array=stats['wer_std']),
            marker=dict(color=stats['color']),
            text=stats['wer_median'].round(2),
            textposition='outside',
            hovertemplate=hover,
            customdata=stats[['wer_std', 'rtfx_median', 'count', 'total_samples']],
        ))
        title_text = "Model Family Performance Leaderboard"

    # Grow with the number of bars, clamped to [400, 700] px
    fig.update_layout(
        title=title_text,
        xaxis_title="Word Error Rate (WER)",
        yaxis_title="",
        height=max(400, min(700, 100 + len(stats) * 35)),
        autosize=True,
        showlegend=False,
        template='plotly_white',
        margin=dict(l=200, r=30, t=60, b=60),
    )
    return fig
def create_cer_leaderboard(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int = 15) -> go.Figure:
    """
    CER leaderboard of model families or individual models.

    - No language filter: one bar per model family showing its median CER,
      with the CER standard deviation as an error bar, ordered by median CER.
    - One or more languages: the ``top_n_models`` individual models with the
      lowest median CER on those languages, each bar colored by its family.

    Rows with CER above ``Q3 + 1.5 * IQR`` are dropped before aggregating
    (low outliers are kept). Sample totals take the first ``num_samples``
    value of each (language, dataset_group) pair per bar.

    Args:
        df: Evaluation results; the ``language``, ``model_family``, ``model``,
            ``dataset_group``, ``wer``, ``cer``, ``rtfx`` and ``num_samples``
            columns are read.
        languages: Languages to filter by (None or empty = all languages).
        top_n_models: How many individual models to show when filtering by
            language (default: 15).

    Returns:
        A Plotly horizontal bar chart.
    """
    subset = df[df['language'].isin(languages)] if languages else df.copy()

    # High-CER outlier removal (same IQR rule used for WER elsewhere)
    if not subset.empty and 'cer' in subset.columns:
        cer = subset['cer']
        q1 = cer.quantile(0.25)
        q3 = cer.quantile(0.75)
        subset = subset[cer <= q3 + 1.5 * (q3 - q1)]

    if languages:
        # Individual-model view: best top_n_models by median CER
        keys = ['model_family', 'model']
        stats = subset.groupby(keys).agg(
            cer_median=('cer', 'median'),
            cer_std=('cer', 'std'),
            count=('cer', 'count'),
            wer_median=('wer', 'median'),
            rtfx_median=('rtfx', 'median'),
        ).reset_index()
        per_set = subset.groupby(keys + ['language', 'dataset_group'])['num_samples'].first().reset_index()
        totals = per_set.groupby(keys)['num_samples'].sum().rename('total_samples').reset_index()
        stats = stats.merge(totals, on=keys, how='left')
        stats = stats.sort_values('cer_median').head(top_n_models)
        # Label bars with the part after the last '/' of the model name
        stats['model_short'] = stats['model'].apply(lambda name: name.rsplit('/', 1)[-1])
        stats['color'] = stats['model_family'].apply(_get_color_for_family)

        hover = (
            '<b>%{customdata[0]}</b><br>'
            '<i>Family: %{customdata[1]}</i><br><br>'
            '<b>Median CER:</b> %{x:.3f}<br>'
            '<b>WER:</b> %{customdata[2]:.3f}<br>'
            '<b>RTFx:</b> %{customdata[3]:.1f}<br>'
            '<b>Evaluations:</b> %{customdata[4]}<br>'
            '<b>Samples:</b> %{customdata[5]:,}<extra></extra>'
        )
        fig = go.Figure(go.Bar(
            y=stats['model_short'],
            x=stats['cer_median'],
            orientation='h',
            marker=dict(color=stats['color']),
            text=stats['cer_median'].round(2),
            textposition='outside',
            hovertemplate=hover,
            customdata=stats[['model', 'model_family', 'wer_median', 'rtfx_median', 'count', 'total_samples']],
        ))

        shown = ", ".join(languages[:3])
        if len(languages) > 3:
            shown += "..."
        title_text = f"Top {min(top_n_models, len(stats))} Models by CER for {shown}"
    else:
        # Family view: one aggregated bar per model family
        stats = subset.groupby('model_family').agg(
            cer_median=('cer', 'median'),
            cer_std=('cer', 'std'),
            count=('cer', 'count'),
            wer_median=('wer', 'median'),
            rtfx_median=('rtfx', 'median'),
        ).reset_index()
        # Count each (family, language, dataset_group) sample set once
        per_set = subset.groupby(['model_family', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        totals = per_set.groupby('model_family')['num_samples'].sum().rename('total_samples').reset_index()
        stats = stats.merge(totals, on='model_family', how='left').sort_values('cer_median')
        stats['color'] = stats['model_family'].apply(_get_color_for_family)

        hover = (
            '<b>%{y}</b><br><br>'
            '<b>Median CER:</b> %{x:.3f}<br>'
            '<b>Std Dev:</b> %{customdata[0]:.3f}<br>'
            '<b>WER:</b> %{customdata[1]:.3f}<br>'
            '<b>RTFx:</b> %{customdata[2]:.1f}<br>'
            '<b>Evaluations:</b> %{customdata[3]}<br>'
            '<b>Samples:</b> %{customdata[4]:,}<extra></extra>'
        )
        fig = go.Figure(go.Bar(
            y=stats['model_family'],
            x=stats['cer_median'],
            orientation='h',
            error_x=dict(type='data', array=stats['cer_std']),
            marker=dict(color=stats['color']),
            text=stats['cer_median'].round(2),
            textposition='outside',
            hovertemplate=hover,
            customdata=stats[['cer_std', 'wer_median', 'rtfx_median', 'count', 'total_samples']],
        ))
        title_text = "Model Family Performance by CER"

    # Grow with the number of bars, clamped to [400, 700] px
    fig.update_layout(
        title=title_text,
        xaxis_title="Character Error Rate (CER)",
        yaxis_title="",
        height=max(400, min(700, 100 + len(stats) * 35)),
        autosize=True,
        showlegend=False,
        template='plotly_white',
        margin=dict(l=200, r=30, t=60, b=60),
    )
    return fig
| def create_speed_accuracy_scatter(df: pd.DataFrame, view_mode: str = "model_family", languages: list[str] | None = None) -> go.Figure: | |
| """ | |
| Visualization 2: Speed vs Accuracy Tradeoff | |
| Scatter plot showing the relationship between WER and RTFx with quadrants. | |
| Outliers are removed for cleaner visualization. | |
| Args: | |
| df: DataFrame with evaluation results | |
| view_mode: Either "model_family" (bubbles same size per family, color by family) or | |
| "individual_model" (bubble size = model params, color = model family) | |
| languages: List of languages to filter by (None = all languages) | |
| """ | |
| # Apply language filter if provided | |
| if languages: | |
| df = df[df['language'].isin(languages)] | |
| # Remove WER outliers for cleaner visualization | |
| df = _remove_wer_outliers(df) | |
| if view_mode == "individual_model": | |
| # Aggregate by individual model | |
| model_agg = df.groupby(['model_family', 'model']).agg({ | |
| 'wer': 'median', | |
| 'rtfx': 'median', | |
| 'cer': 'median', | |
| }).reset_index() | |
| # Get unique sample counts per model | |
| unique_samples = df.groupby(['model_family', 'model', 'language', 'dataset_group'])['num_samples'].first().reset_index() | |
| model_samples = unique_samples.groupby(['model_family', 'model'])['num_samples'].sum().reset_index() | |
| model_samples.columns = ['model_family', 'model', 'num_samples'] | |
| model_agg = model_agg.merge(model_samples, on=['model_family', 'model'], how='left') | |
| # Get parameter count for each individual model | |
| model_agg['params'] = model_agg['model'].apply( | |
| lambda m: MODEL_PARAMETER_COUNTS.get(m, 500_000_000) # Default 500M | |
| ) | |
| model_agg['params_billions'] = model_agg['params'] / 1_000_000_000 | |
| model_agg['params_display'] = model_agg['params'].apply( | |
| lambda x: f"{x/1_000_000_000:.1f}B" if x >= 1_000_000_000 else f"{x/1_000_000:.0f}M" | |
| ) | |
| # Create short model name for display | |
| model_agg['model_short'] = model_agg['model'].apply( | |
| lambda x: x.split('/')[-1] if '/' in x else x | |
| ) | |
| fig = go.Figure() | |
| # Add scatter traces for each model family | |
| for family in model_agg['model_family'].unique(): | |
| family_data = model_agg[model_agg['model_family'] == family] | |
| family_color = _get_color_for_family(family) | |
| fig.add_trace(go.Scatter( | |
| x=family_data['wer'], | |
| y=family_data['rtfx'], | |
| mode='markers', | |
| name=family, | |
| marker=dict( | |
| size=family_data['params'] / family_data['params'].max() * 50 + 10, | |
| sizemode='diameter', | |
| sizemin=8, | |
| color=family_color | |
| ), | |
| customdata=family_data[['model_short', 'cer', 'num_samples', 'params_display']].values, | |
| hovertemplate=( | |
| '<b>%{customdata[0]}</b><br><br>' + | |
| '<b>WER:</b> %{x:.3f}<br>' + | |
| '<b>RTFx:</b> %{y:.1f}<br>' + | |
| '<b>CER:</b> %{customdata[1]:.3f}<br>' + | |
| '<b>Samples:</b> %{customdata[2]:,}<br>' + | |
| '<b>Parameters:</b> %{customdata[3]}<extra></extra>' | |
| ) | |
| )) | |
| title_text = "Speed vs Accuracy Tradeoff by Individual Model" | |
| else: | |
| # Original behavior: aggregate by model family | |
| model_agg = df.groupby('model_family').agg({ | |
| 'wer': 'median', | |
| 'rtfx': 'median', | |
| 'cer': 'median', | |
| 'model': 'first' # Get a representative model name for parameter lookup | |
| }).reset_index() | |
| # Get unique sample counts per model family | |
| unique_samples = df.groupby(['model_family', 'language', 'dataset_group'])['num_samples'].first().reset_index() | |
| family_samples = unique_samples.groupby('model_family')['num_samples'].sum().reset_index() | |
| family_samples.columns = ['model_family', 'num_samples'] | |
| model_agg = model_agg.merge(family_samples, on='model_family', how='left') | |
| # For model family view, use a uniform size (no bubble size variation) | |
| # Use a constant for params to make bubbles the same size | |
| model_agg['params'] = 1_000_000_000 # Use constant 1B for uniform bubble size | |
| model_agg['params_display'] = 'N/A' # Not applicable in family view | |
| fig = go.Figure() | |
| # Add scatter traces for each model family | |
| for family in model_agg['model_family'].unique(): | |
| family_data = model_agg[model_agg['model_family'] == family] | |
| family_color = _get_color_for_family(family) | |
| fig.add_trace(go.Scatter( | |
| x=family_data['wer'], | |
| y=family_data['rtfx'], | |
| mode='markers', | |
| name=family, | |
| marker=dict( | |
| size=20, | |
| sizemode='diameter', | |
| color=family_color | |
| ), | |
| customdata=family_data[['cer', 'num_samples']].values, | |
| hovertemplate=( | |
| '<b>' + family + '</b><br><br>' + | |
| '<b>WER:</b> %{x:.3f}<br>' + | |
| '<b>RTFx:</b> %{y:.1f}<br>' + | |
| '<b>CER:</b> %{customdata[0]:.3f}<br>' + | |
| '<b>Samples:</b> %{customdata[1]:,}<extra></extra>' | |
| ) | |
| )) | |
| title_text = "Speed vs Accuracy Tradeoff by Model Family" | |
| # Add quadrant lines at median | |
| median_wer = model_agg['wer'].median() | |
| median_rtfx = model_agg['rtfx'].median() | |
| fig.add_hline(y=median_rtfx, line_dash="dash", line_color="gray", annotation_text="Median RTFx", annotation_position="right") | |
| fig.add_vline(x=median_wer, line_dash="dash", line_color="gray", annotation_text="Median WER", annotation_position="top") | |
| # Add quadrant label - centered in the "Fast & Accurate" quadrant (low WER, high RTFx) | |
| # The ideal quadrant is: x from min to median_wer, y from median_rtfx to max | |
| quadrant_center_x = (model_agg['wer'].min() + median_wer) / 2 | |
| quadrant_center_y = (median_rtfx + model_agg['rtfx'].max()) / 2 | |
| fig.add_annotation( | |
| x=quadrant_center_x, | |
| y=quadrant_center_y, | |
| text="Fast & Accurate ⭐", | |
| showarrow=False, | |
| font=dict(size=12, color="green", family="Arial Black") | |
| ) | |
| fig.update_layout( | |
| title=title_text, | |
| xaxis_title="WER", | |
| yaxis_title="RTFx", | |
| yaxis=dict(rangemode='tozero'), # Ensure y-axis starts at 0 (RTFx can't be negative) | |
| height=550, | |
| autosize=True, | |
| template='plotly_white', | |
| legend=dict( | |
| orientation="h", | |
| yanchor="top", | |
| y=-0.15, | |
| xanchor="center", | |
| x=0.5, | |
| font=dict(size=10) | |
| ), | |
| margin=dict(l=60, r=20, t=50, b=120) | |
| ) | |
| return fig | |
def create_wer_cer_correlation(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int | None = None) -> go.Figure:
    """
    Visualization 7: WER vs CER Correlation

    Scatter plot showing the relationship between word and character error rates.
    Each row of the filtered results becomes one point, colored by model family
    and sized by its sample count. A least-squares trendline and the Pearson
    correlation coefficient (shown in the title) summarize the relationship.
    Defaults to Swahili if no language is specified.
    Outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results. Columns read here: 'language',
            'model', 'model_family', 'dataset_group', 'wer', 'cer', 'num_samples'.
        languages: List of languages to filter by (defaults to ["Swahili"] if
            None or empty).
        top_n_models: If specified, only show the top N models ranked by median
            WER (None or 0 = show all).

    Returns:
        A Plotly figure. Both axes are fixed to the range [0, 1.5]. The title
        shows "r=n/a" when the correlation is undefined (fewer than two complete
        WER/CER pairs, or a constant metric).
    """
    import numpy as np  # only needed for the trendline fit

    # Default to Swahili if no language filter provided
    if not languages:
        languages = ["Swahili"]

    # Boolean indexing already returns a new DataFrame, so no explicit copy is needed.
    filtered_df = df[df['language'].isin(languages)]
    # Remove WER outliers for cleaner visualization
    filtered_df = _remove_wer_outliers(filtered_df)

    # Apply top N models filter if specified
    if top_n_models and top_n_models > 0:
        model_wer = filtered_df.groupby('model')['wer'].median().sort_values()
        top_models = model_wer.head(top_n_models).index.tolist()
        filtered_df = filtered_df[filtered_df['model'].isin(top_models)]

    # Marker sizes are scaled against the largest sample count across ALL families.
    # The value is loop-invariant, so compute it once. Fall back to 1 when the maximum
    # is missing (empty frame / all-NaN) or non-positive. Otherwise the division below
    # would produce NaN or infinite marker sizes.
    max_samples = filtered_df['num_samples'].max() if not filtered_df.empty else 1
    if pd.isna(max_samples) or max_samples <= 0:
        max_samples = 1

    fig = go.Figure()
    # Add scatter traces for each model family
    for family in filtered_df['model_family'].unique():
        family_data = filtered_df[filtered_df['model_family'] == family]
        family_color = _get_color_for_family(family)
        # Map sample counts onto diameters in the range [5, 30]
        sizes = (family_data['num_samples'] / max_samples * 25 + 5).values if not family_data.empty else [10]
        fig.add_trace(go.Scatter(
            x=family_data['wer'],
            y=family_data['cer'],
            mode='markers',
            name=family,
            marker=dict(
                size=sizes,
                sizemode='diameter',
                sizemin=5,
                opacity=0.6,
                color=family_color
            ),
            customdata=family_data[['language', 'model', 'dataset_group', 'num_samples']].values,
            hovertemplate=(
                '<b>%{customdata[0]}</b><br><br>' +
                '<b>Model:</b> %{customdata[1]}<br>' +
                '<b>Dataset:</b> %{customdata[2]}<br>' +
                '<b>WER:</b> %{x:.3f}<br>' +
                '<b>CER:</b> %{y:.3f}<br>' +
                '<b>Samples:</b> %{customdata[3]:,}<extra></extra>'
            )
        ))

    # Fit the trendline and compute the correlation on complete WER/CER pairs only.
    # np.polyfit fails on NaN input. A single distinct WER value gives a
    # rank-deficient fit and a zero-width line, so skip the trendline in that case.
    pairs = filtered_df[['wer', 'cer']].dropna()
    if len(pairs) > 1 and pairs['wer'].nunique() > 1:
        coeffs = np.polyfit(pairs['wer'], pairs['cer'], 1)
        trend = np.poly1d(coeffs)
        x_range = np.linspace(pairs['wer'].min(), pairs['wer'].max(), 100)
        fig.add_trace(go.Scatter(
            x=x_range,
            y=trend(x_range),
            mode='lines',
            name='Trend',
            line=dict(color='gray', dash='dash'),
            hoverinfo='skip'
        ))

    # Pearson correlation. It is NaN when undefined (too few pairs or a constant metric),
    # and is then rendered as "n/a" rather than a misleading 0.00 or "nan".
    correlation = pairs['wer'].corr(pairs['cer']) if len(pairs) > 1 else float('nan')
    r_text = f"{correlation:.2f}" if pd.notna(correlation) else "n/a"

    # Build title with language info (at most three names, then an ellipsis)
    lang_str = ", ".join(languages[:3]) + ("..." if len(languages) > 3 else "")
    title_text = f"WER vs CER Correlation for {lang_str} (r={r_text})"

    fig.update_layout(
        title=title_text,
        xaxis_title="WER",
        yaxis_title="CER",
        height=550,
        autosize=True,
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(size=10)
        ),
        margin=dict(l=60, r=20, t=50, b=120),
        # Zoom both axes to the useful range (0-1.5)
        xaxis=dict(range=[0, 1.5]),
        yaxis=dict(range=[0, 1.5])
    )
    return fig
def create_model_consistency(df: pd.DataFrame) -> go.Figure:
    """
    Visualization 9: Model Consistency Analysis

    Shows, per model family, the coefficient of variation (CV) of WER across its
    evaluations as a horizontal bar chart, to measure consistency across languages.
    The CV here is relative to the MEDIAN WER (std / median * 100), not the mean.
    Removes high outliers only using the IQR method (keeps best performers).

    Families whose CV is undefined are omitted rather than drawn as empty bars
    labelled "nan"/"inf". This covers fewer than two remaining evaluations
    (standard deviation is NaN) and a median WER of zero (division by zero).

    Args:
        df: DataFrame with evaluation results; reads the 'wer' and
            'model_family' columns.

    Returns:
        A Plotly figure with one horizontal bar per model family, ordered by
        ascending CV.
    """
    # Remove only HIGH outliers using the IQR method on WER (keep best performers)
    q1 = df['wer'].quantile(0.25)
    q3 = df['wer'].quantile(0.75)
    upper_bound = q3 + 1.5 * (q3 - q1)
    df_no_outliers = df[df['wer'] <= upper_bound]

    model_variance = df_no_outliers.groupby('model_family').agg({
        'wer': ['median', 'std', 'count']
    }).reset_index()
    model_variance.columns = ['model_family', 'wer_median', 'wer_std', 'count']

    # Keep only families for which the CV is defined. A single evaluation has a NaN
    # std, and a zero median would divide by zero.
    has_defined_cv = (model_variance['count'] >= 2) & (model_variance['wer_median'] > 0)
    model_variance = model_variance[has_defined_cv].copy()

    model_variance['cv'] = model_variance['wer_std'] / model_variance['wer_median'] * 100
    model_variance = model_variance.sort_values('cv')

    # Get colors for each model family
    model_variance['color'] = model_variance['model_family'].apply(_get_color_for_family)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=model_variance['model_family'],
        x=model_variance['cv'],
        orientation='h',
        marker=dict(
            color=model_variance['color']
        ),
        text=model_variance['cv'].round(1),
        textposition='outside',
        hovertemplate=(
            '<b>%{y}</b><br><br>' +
            '<b>Coefficient of Variation:</b> %{x:.1f}%<br>' +
            '<b>Median WER:</b> %{customdata[0]:.3f}<br>' +
            '<b>Std Dev:</b> %{customdata[1]:.3f}<br>' +
            '<b>Evaluations:</b> %{customdata[2]}<extra></extra>'
        ),
        customdata=model_variance[['wer_median', 'wer_std', 'count']].values
    ))

    fig.update_layout(
        title="Model Consistency Ranking (Outliers Removed)",
        xaxis_title="Coefficient of Variation (%)",
        yaxis_title="Model Family",
        height=550,
        template='plotly_white',
        showlegend=False,
        margin=dict(l=200, r=100, t=80, b=80)
    )
    return fig