"""
PazaBench Visualization Functions for Gradio Integration
"""
import os

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from src.constants import (
    COUNTRY_NAMES,
    LANGUAGE_COUNTRY_MAP,
    LANGUAGE_SAMPLE_COUNTS,
)
from src.model_counts import MODEL_PARAMETER_COUNTS

# Color used whenever a model or family has no entry in the color CSV.
_FALLBACK_COLOR = '#888888'

# Parameter count assumed for models missing from MODEL_PARAMETER_COUNTS (500M).
_DEFAULT_PARAMETER_COUNT = 500_000_000

# Legend/margin settings shared by the two scatter plots (legend below the plot).
_BOTTOM_LEGEND = dict(
    orientation="h",
    yanchor="top",
    y=-0.15,
    xanchor="center",
    x=0.5,
    font=dict(size=10),
)
_SCATTER_MARGIN = dict(l=60, r=20, t=50, b=120)

# NOTE: hover templates in this module separate lines with "<br>", which is the
# line-break markup Plotly understands in hover labels.


def _normalize_family_name(family: str) -> str:
    """Normalize family name for consistent lookup (lowercase, remove underscores/dashes/spaces)."""
    return family.lower().replace('_', '').replace('-', '').replace(' ', '')


# Load model family colors from CSV
def _load_model_family_colors() -> tuple[dict[str, str], dict[str, str]]:
    """
    Load color mappings from the model_family_colors.csv file.

    Returns:
        - model_family_colors: dict mapping normalized model_family -> color
          (first color encountered for that family)
        - model_id_colors: dict mapping model_id -> color

    Any failure (missing file, missing column, bad row) is reported with a
    printed warning and whatever was loaded so far is returned, so the module
    can still be imported and charts fall back to the default color.
    """
    csv_path = os.path.join(os.path.dirname(__file__), 'display', 'model_family_colors.csv')
    model_family_colors: dict[str, str] = {}
    model_id_colors: dict[str, str] = {}

    try:
        df = pd.read_csv(csv_path)
        rows = zip(df['Model Families'], df['Model ID'], df['Color'])
        for family, model_id, color in rows:
            # Keep only the first color seen for each (normalized) family.
            model_family_colors.setdefault(_normalize_family_name(family), color)
            model_id_colors[model_id] = color
    except Exception as e:  # deliberate best-effort: colors are cosmetic
        print(f"Warning: Could not load model family colors: {e}")

    return model_family_colors, model_id_colors


MODEL_FAMILY_COLORS, MODEL_ID_COLORS = _load_model_family_colors()


def _get_color_for_family(family: str) -> str:
    """Get color for a model family, with fallback."""
    return MODEL_FAMILY_COLORS.get(_normalize_family_name(family), _FALLBACK_COLOR)


def _get_color_for_model(model_id: str) -> str:
    """
    Get color for a specific model ID, with fallback to family color.

    The family lookup is a substring match. Family keys are stored normalized
    (lowercase, no '_', '-' or spaces), so the model ID is normalized the same
    way before matching; otherwise families containing separators never match.
    """
    if model_id in MODEL_ID_COLORS:
        return MODEL_ID_COLORS[model_id]

    normalized_id = _normalize_family_name(model_id)
    for family, color in MODEL_FAMILY_COLORS.items():
        if family in normalized_id:
            return color
    return _FALLBACK_COLOR


def _remove_wer_outliers(df: pd.DataFrame, multiplier: float = 1.5, column: str = 'wer') -> pd.DataFrame:
    """
    Remove high outliers using the IQR method for cleaner visualizations.

    Only HIGH outliers (poor performers) are removed; LOW outliers (best
    performers) are kept, because lower error rates are better.

    Args:
        df: DataFrame containing the metric column
        multiplier: IQR multiplier (default 1.5 for standard outlier detection)
        column: Metric column to filter on (default 'wer')

    Returns:
        DataFrame with rows above Q3 + multiplier * IQR removed. The input is
        returned unchanged when it is empty or lacks the column.
    """
    if df.empty or column not in df.columns:
        return df

    q1 = df[column].quantile(0.25)
    q3 = df[column].quantile(0.75)
    upper_bound = q3 + multiplier * (q3 - q1)
    return df[df[column] <= upper_bound]


def _short_model_name(model_id: str) -> str:
    """Return the part of a model ID after the last '/', or the ID itself."""
    return model_id.split('/')[-1]


def _format_language_list(languages: list[str]) -> str:
    """Join at most three language names for a title, adding '...' if truncated."""
    return ", ".join(languages[:3]) + ("..." if len(languages) > 3 else "")


def _format_parameter_count(params: float) -> str:
    """Format a parameter count as e.g. '1.5B' or '300M'."""
    if params >= 1_000_000_000:
        return f"{params/1_000_000_000:.1f}B"
    return f"{params/1_000_000:.0f}M"


def _sum_unique_samples(df: pd.DataFrame, group_cols: list[str], out_col: str) -> pd.DataFrame:
    """
    Sum 'num_samples' per group, counting each (language, dataset_group) once.

    Taking the first value per (group, language, dataset_group) avoids
    double-counting the same dataset when several rows refer to it.
    """
    per_dataset = (
        df.groupby([*group_cols, 'language', 'dataset_group'])['num_samples']
        .first()
        .reset_index()
    )
    totals = per_dataset.groupby(group_cols)['num_samples'].sum().reset_index()
    totals.columns = [*group_cols, out_col]
    return totals


def _apply_leaderboard_layout(fig: go.Figure, title_text: str, xaxis_title: str, num_items: int) -> None:
    """Apply the shared leaderboard layout; height grows with the bar count (400-700px)."""
    height = max(400, min(700, 100 + num_items * 35))
    fig.update_layout(
        title=title_text,
        xaxis_title=xaxis_title,
        yaxis_title="",
        height=height,
        autosize=True,
        showlegend=False,
        template='plotly_white',
        margin=dict(l=200, r=30, t=60, b=60),
    )


def _apply_africa_map_layout(fig: go.Figure, title_text: str) -> None:
    """Apply the shared Africa-scoped geo settings and layout to a choropleth figure."""
    fig.update_geos(
        visible=True,
        resolution=50,
        scope="africa",
        showcountries=True,
        countrycolor="lightgray",
        showcoastlines=True,
        coastlinecolor="gray",
        showland=True,
        landcolor="#f5f5f5",
        showocean=True,
        oceancolor="#e3f2fd",
        showlakes=True,
        lakecolor="#e3f2fd",
        projection_type="natural earth",
        center=dict(lat=5, lon=20),
    )
    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=18),
            x=0.5,
        ),
        height=600,
        autosize=True,
        margin=dict(l=5, r=5, t=50, b=5),
        geo=dict(
            bgcolor="rgba(0,0,0,0)",
        ),
        template="plotly_white",
    )


def create_language_coverage_chart(selected_languages: list[str] | None = None) -> go.Figure:
    """
    Create a horizontal bar chart showing sample counts for each language in PazaBench.

    Languages are ordered by sample count. Selected languages are highlighted
    with a different color and a border.

    Args:
        selected_languages: List of languages to highlight (None = no highlighting)
    """
    df = pd.DataFrame(
        list(LANGUAGE_SAMPLE_COUNTS.items()),
        columns=["Language", "Sample Count"],
    )
    # Rows are sorted ascending because horizontal bars are drawn bottom to top.
    df = df.sort_values("Sample Count", ascending=True)

    if selected_languages:
        selected = set(selected_languages)
        flags = [lang in selected for lang in df["Language"]]
        colors = ["#0078D4" if flag else "#D0E8FF" for flag in flags]
        # Border only around the selected bars.
        marker_line = dict(
            width=[2 if flag else 0 for flag in flags],
            color=["#005A9E" if flag else "rgba(0,0,0,0)" for flag in flags],
        )
        title_text = f"Language Coverage in PazaBench ({len(selected_languages)} selected)"
    else:
        colors = "#8CD0FF"  # Solid blue color matching theme
        marker_line = None
        title_text = "Language Coverage in PazaBench"

    fig = go.Figure(go.Bar(
        y=df["Language"],
        x=df["Sample Count"],
        orientation='h',
        marker=dict(
            color=colors,
            line=marker_line,
        ),
        text=None,  # No text labels on bars
        textposition='none',
        hovertemplate='%{y}<br>Samples: %{x:,}',
    ))

    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=16),
            x=0.5,
        ),
        xaxis_title="Number of Samples",
        yaxis_title="",
        height=800,
        autosize=True,
        margin=dict(l=120, r=50, t=60, b=40),
        template='plotly_white',
        showlegend=False,
    )

    return fig


def create_language_location_map(languages: str | list[str] | None = None) -> go.Figure:
    """
    Create an interactive choropleth map of Africa showing where specific language(s) exist.

    If no language is selected, shows a light overview map with all PazaBench
    countries highlighted. If languages are given but none is present in
    LANGUAGE_COUNTRY_MAP, an empty map with the default title is returned.

    Args:
        languages: The language(s) to highlight on the map (single string or list)
    """
    fig = go.Figure()

    # Normalize to list
    if isinstance(languages, str):
        languages = [languages]

    title_text = "Select a Language to Explore"

    if languages:
        # Map each country code to the selected languages spoken there.
        country_language_map: dict[str, list[str]] = {}
        for lang in languages:
            for code in LANGUAGE_COUNTRY_MAP.get(lang, []):
                country_language_map.setdefault(code, []).append(lang)

        if country_language_map:
            df_map = pd.DataFrame([
                {
                    "country_code": code,
                    "country_name": COUNTRY_NAMES.get(code, code),
                    "has_language": 1,
                    "languages": ", ".join(langs),
                }
                for code, langs in country_language_map.items()
            ])

            if len(languages) == 1:
                hover_template = "%{text}<br>" + f"{languages[0]} is spoken here"
                title_text = f"Where {languages[0]} is Spoken"
            else:
                hover_template = "%{text}<br>Languages: %{customdata}"
                title_text = f"Where {len(languages)} Selected Languages are Spoken"

            fig.add_trace(go.Choropleth(
                locations=df_map["country_code"],
                z=df_map["has_language"],
                text=df_map["country_name"],
                customdata=df_map["languages"],
                hovertemplate=hover_template,
                colorscale=[[0, "#8CD0FF"], [1, "#8CD0FF"]],  # Solid blue color
                showscale=False,
                marker_line_color="white",
                marker_line_width=0.5,
            ))
    else:
        # Default view: show all countries with PazaBench data lightly highlighted
        all_countries = set()
        for countries in LANGUAGE_COUNTRY_MAP.values():
            all_countries.update(countries)

        df_map = pd.DataFrame(
            [
                {
                    "country_code": code,
                    "country_name": COUNTRY_NAMES.get(code, code),
                    "in_pazabench": 1,
                }
                for code in all_countries
            ],
            columns=["country_code", "country_name", "in_pazabench"],
        )

        fig.add_trace(go.Choropleth(
            locations=df_map["country_code"],
            z=df_map["in_pazabench"],
            text=df_map["country_name"],
            hovertemplate="%{text}<br>Has PazaBench data",
            colorscale=[[0, "#E8F4FF"], [1, "#E8F4FF"]],  # Very light blue matching theme
            showscale=False,
            marker_line_color="#B3D9FF",
            marker_line_width=0.5,
        ))

    _apply_africa_map_layout(fig, title_text)
    return fig


def get_language_sample_info(languages: str | list[str] | None = None, asr_df: pd.DataFrame | None = None) -> str:
    """
    Get the sample count information for specific language(s) as text for display in Gradio.

    When at least one requested language is known, the summary covers the
    total samples, the countries and the dataset sources of those languages.
    Otherwise a benchmark-wide overview is returned.

    Args:
        languages: The language(s) to get information for (single string or list)
        asr_df: DataFrame with ASR results to extract dataset groups
            (uses its 'language' and 'dataset_group' columns when present)
    """
    # NOTE(review): the templates below contain no HTML tags in this revision;
    # confirm whether styling markup is expected here.

    # Normalize to list
    if isinstance(languages, str):
        languages = [languages]

    valid_languages = [lang for lang in (languages or []) if lang in LANGUAGE_SAMPLE_COUNTS]

    if valid_languages:
        # Aggregate data across all selected languages
        total_samples = sum(LANGUAGE_SAMPLE_COUNTS[lang] for lang in valid_languages)

        all_countries = set()
        for lang in valid_languages:
            all_countries.update(LANGUAGE_COUNTRY_MAP.get(lang, []))
        country_names = sorted(COUNTRY_NAMES.get(code, code) for code in all_countries)

        # Get dataset groups from ASR data if available
        dataset_groups: list = []
        if (
            asr_df is not None
            and not asr_df.empty
            and 'language' in asr_df.columns
            and 'dataset_group' in asr_df.columns
        ):
            matching = asr_df.loc[asr_df['language'].isin(valid_languages), 'dataset_group']
            # Missing group names are dropped so sorting never mixes NaN with strings.
            dataset_groups = sorted(matching.dropna().unique().tolist())

        title = f"🌍 {', '.join(sorted(valid_languages))}"
        countries_text = ', '.join(country_names) if country_names else 'N/A'
        sources_text = ', '.join(dataset_groups) if dataset_groups else 'No dataset info available'

        html = f"""

{title}

📊 Total Samples {total_samples:,}
📍 Countries ({len(country_names)})
{countries_text}
📁 Dataset Sources ({len(dataset_groups)})
{sources_text}
"""
        return html

    # Default view - show sample overview summary
    total_languages = len(LANGUAGE_SAMPLE_COUNTS)
    total_samples = sum(LANGUAGE_SAMPLE_COUNTS.values())
    total_countries = len({code for codes in LANGUAGE_COUNTRY_MAP.values() for code in codes})

    html = f"""

📊 Sample Overview

🌍 Languages {total_languages}
📊 Total Samples {total_samples:,}
📍 Countries {total_countries}

👈 Select a language on the left to explore its details

"""
    return html


def get_language_sample_info_df(language: str | None = None) -> pd.DataFrame:
    """
    Legacy function - returns DataFrame for backward compatibility.

    Returns a two-column (Metric, Value) frame describing the language, or
    placeholder values when the language is missing or unknown.
    """
    if language and language in LANGUAGE_SAMPLE_COUNTS:
        country_names = [
            COUNTRY_NAMES.get(code, code)
            for code in LANGUAGE_COUNTRY_MAP.get(language, [])
        ]
        values = [
            language,
            f"{LANGUAGE_SAMPLE_COUNTS[language]:,}",
            ", ".join(country_names) if country_names else "N/A",
        ]
    else:
        values = ["Select a language", "-", "-"]

    return pd.DataFrame({
        "Metric": ["Language", "Total Samples", "Countries"],
        "Value": values,
    })


def get_all_languages() -> list[str]:
    """Get a sorted list of all languages in PazaBench."""
    return sorted(LANGUAGE_SAMPLE_COUNTS)


def create_africa_language_map() -> go.Figure:
    """
    Create an interactive choropleth map of Africa showing language coverage.

    Countries are shaded by the number of PazaBench languages mapped to them.
    Hover over countries to see the languages spoken there.
    """
    # Collect the languages mapped to each country code.
    languages_by_country: dict[str, list[str]] = {}
    for language, countries in LANGUAGE_COUNTRY_MAP.items():
        for country_code in countries:
            languages_by_country.setdefault(country_code, []).append(language)

    df_map = pd.DataFrame(
        [
            {
                "country_code": code,
                "country_name": COUNTRY_NAMES.get(code, code),
                "language_count": len(langs),
                "languages": ", ".join(sorted(langs)),
            }
            for code, langs in languages_by_country.items()
        ],
        columns=["country_code", "country_name", "language_count", "languages"],
    )

    fig = go.Figure(go.Choropleth(
        locations=df_map["country_code"],
        z=df_map["language_count"],
        text=df_map["country_name"],
        customdata=df_map[["languages", "language_count"]],
        hovertemplate=(
            "%{text}<br>"
            "Languages: %{customdata[1]}<br>"
            "%{customdata[0]}"
        ),
        colorscale=[
            [0, "#DDF1FF"],
            [0.25, "#B5E0FF"],
            [0.5, "#8CD0FF"],
            [0.75, "#6BC0F5"],
            [1, "#4AAFEB"],
        ],
        showscale=True,
        colorbar=dict(
            title="Languages",
            tickmode="linear",
            tick0=1,
            dtick=1,
        ),
        marker_line_color="white",
        marker_line_width=0.5,
    ))

    _apply_africa_map_layout(fig, "African Languages in PazaBench")
    return fig


def create_model_leaderboard(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int = 15) -> go.Figure:
    """
    Visualization 1: Model Family / Individual Model Performance Leaderboard

    - When no language filter: Shows model families (aggregated)
    - When language(s) selected: Shows top N individual models

    High WER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        languages: List of languages to filter by (None = all languages)
        top_n_models: Number of top individual models to show when languages are filtered (default: 15)
    """
    filtered_df = df[df['language'].isin(languages)] if languages else df
    filtered_df = _remove_wer_outliers(filtered_df)

    fig = go.Figure()

    if languages:
        # Individual model mode: show top N models by median WER
        group_cols = ['model_family', 'model']
        model_perf = (
            filtered_df.groupby(group_cols)
            .agg(
                wer_median=('wer', 'median'),
                wer_std=('wer', 'std'),
                count=('wer', 'count'),
                cer_median=('cer', 'median'),
                rtfx_median=('rtfx', 'median'),
            )
            .reset_index()
            .merge(_sum_unique_samples(filtered_df, group_cols, 'total_samples'), on=group_cols, how='left')
            .sort_values('wer_median')
            .head(top_n_models)
        )
        model_perf = model_perf.assign(
            model_short=model_perf['model'].apply(_short_model_name),
            color=model_perf['model_family'].apply(_get_color_for_family),
        )

        fig.add_trace(go.Bar(
            y=model_perf['model_short'],
            x=model_perf['wer_median'],
            orientation='h',
            marker=dict(color=model_perf['color']),
            text=model_perf['wer_median'].round(2),
            textposition='outside',
            hovertemplate=(
                '%{customdata[0]}<br>'
                'Family: %{customdata[1]}<br><br>'
                'Median WER: %{x:.3f}<br>'
                'RTFx: %{customdata[2]:.1f}<br>'
                'Evaluations: %{customdata[3]}<br>'
                'Samples: %{customdata[4]:,}'
            ),
            customdata=model_perf[['model', 'model_family', 'rtfx_median', 'count', 'total_samples']],
        ))

        lang_str = _format_language_list(languages)
        title_text = f"Top {min(top_n_models, len(model_perf))} Models for {lang_str}"
    else:
        # Model family mode: one bar per family, with std-dev error bars
        model_perf = (
            filtered_df.groupby('model_family')
            .agg(
                wer_median=('wer', 'median'),
                wer_std=('wer', 'std'),
                count=('wer', 'count'),
                cer_median=('cer', 'median'),
                rtfx_median=('rtfx', 'median'),
            )
            .reset_index()
            .merge(_sum_unique_samples(filtered_df, ['model_family'], 'total_samples'), on='model_family', how='left')
            .sort_values('wer_median')
        )
        model_perf = model_perf.assign(
            color=model_perf['model_family'].apply(_get_color_for_family),
        )

        fig.add_trace(go.Bar(
            y=model_perf['model_family'],
            x=model_perf['wer_median'],
            orientation='h',
            error_x=dict(type='data', array=model_perf['wer_std']),
            marker=dict(color=model_perf['color']),
            text=model_perf['wer_median'].round(2),
            textposition='outside',
            hovertemplate=(
                '%{y}<br><br>'
                'Median WER: %{x:.3f}<br>'
                'Std Dev: %{customdata[0]:.3f}<br>'
                'RTFx: %{customdata[1]:.1f}<br>'
                'Evaluations: %{customdata[2]}<br>'
                'Samples: %{customdata[3]:,}'
            ),
            customdata=model_perf[['wer_std', 'rtfx_median', 'count', 'total_samples']],
        ))

        title_text = "Model Family Performance Leaderboard"

    _apply_leaderboard_layout(fig, title_text, "Word Error Rate (WER)", len(model_perf))
    return fig


def create_cer_leaderboard(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int = 15) -> go.Figure:
    """
    Visualization: CER Model Family / Individual Model Performance Leaderboard

    - When no language filter: Shows model families (aggregated)
    - When language(s) selected: Shows top N individual models

    High CER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        languages: List of languages to filter by (None = all languages)
        top_n_models: Number of top individual models to show when languages are filtered (default: 15)
    """
    filtered_df = df[df['language'].isin(languages)] if languages else df
    # Same IQR rule as the WER charts, applied to the CER column.
    filtered_df = _remove_wer_outliers(filtered_df, column='cer')

    fig = go.Figure()

    if languages:
        # Individual model mode: show top N models by median CER
        group_cols = ['model_family', 'model']
        model_perf = (
            filtered_df.groupby(group_cols)
            .agg(
                cer_median=('cer', 'median'),
                cer_std=('cer', 'std'),
                count=('cer', 'count'),
                wer_median=('wer', 'median'),
                rtfx_median=('rtfx', 'median'),
            )
            .reset_index()
            .merge(_sum_unique_samples(filtered_df, group_cols, 'total_samples'), on=group_cols, how='left')
            .sort_values('cer_median')
            .head(top_n_models)
        )
        model_perf = model_perf.assign(
            model_short=model_perf['model'].apply(_short_model_name),
            color=model_perf['model_family'].apply(_get_color_for_family),
        )

        fig.add_trace(go.Bar(
            y=model_perf['model_short'],
            x=model_perf['cer_median'],
            orientation='h',
            marker=dict(color=model_perf['color']),
            text=model_perf['cer_median'].round(2),
            textposition='outside',
            hovertemplate=(
                '%{customdata[0]}<br>'
                'Family: %{customdata[1]}<br><br>'
                'Median CER: %{x:.3f}<br>'
                'WER: %{customdata[2]:.3f}<br>'
                'RTFx: %{customdata[3]:.1f}<br>'
                'Evaluations: %{customdata[4]}<br>'
                'Samples: %{customdata[5]:,}'
            ),
            customdata=model_perf[['model', 'model_family', 'wer_median', 'rtfx_median', 'count', 'total_samples']],
        ))

        lang_str = _format_language_list(languages)
        title_text = f"Top {min(top_n_models, len(model_perf))} Models by CER for {lang_str}"
    else:
        # Model family mode: one bar per family, with std-dev error bars
        model_perf = (
            filtered_df.groupby('model_family')
            .agg(
                cer_median=('cer', 'median'),
                cer_std=('cer', 'std'),
                count=('cer', 'count'),
                wer_median=('wer', 'median'),
                rtfx_median=('rtfx', 'median'),
            )
            .reset_index()
            .merge(_sum_unique_samples(filtered_df, ['model_family'], 'total_samples'), on='model_family', how='left')
            .sort_values('cer_median')
        )
        model_perf = model_perf.assign(
            color=model_perf['model_family'].apply(_get_color_for_family),
        )

        fig.add_trace(go.Bar(
            y=model_perf['model_family'],
            x=model_perf['cer_median'],
            orientation='h',
            error_x=dict(type='data', array=model_perf['cer_std']),
            marker=dict(color=model_perf['color']),
            text=model_perf['cer_median'].round(2),
            textposition='outside',
            hovertemplate=(
                '%{y}<br><br>'
                'Median CER: %{x:.3f}<br>'
                'Std Dev: %{customdata[0]:.3f}<br>'
                'WER: %{customdata[1]:.3f}<br>'
                'RTFx: %{customdata[2]:.1f}<br>'
                'Evaluations: %{customdata[3]}<br>'
                'Samples: %{customdata[4]:,}'
            ),
            customdata=model_perf[['cer_std', 'wer_median', 'rtfx_median', 'count', 'total_samples']],
        ))

        title_text = "Model Family Performance by CER"

    _apply_leaderboard_layout(fig, title_text, "Character Error Rate (CER)", len(model_perf))
    return fig


def create_speed_accuracy_scatter(df: pd.DataFrame, view_mode: str = "model_family", languages: list[str] | None = None) -> go.Figure:
    """
    Visualization 2: Speed vs Accuracy Tradeoff

    Scatter plot showing the relationship between WER and RTFx with quadrants.
    High WER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        view_mode: Either "model_family" (bubbles same size per family, color by family)
                   or "individual_model" (bubble size = model params, color = model family)
        languages: List of languages to filter by (None = all languages)
    """
    if languages:
        df = df[df['language'].isin(languages)]
    df = _remove_wer_outliers(df)

    fig = go.Figure()

    if view_mode == "individual_model":
        group_cols = ['model_family', 'model']
        model_agg = (
            df.groupby(group_cols)
            .agg({'wer': 'median', 'rtfx': 'median', 'cer': 'median'})
            .reset_index()
            .merge(_sum_unique_samples(df, group_cols, 'num_samples'), on=group_cols, how='left')
        )

        model_agg['params'] = model_agg['model'].apply(
            lambda m: MODEL_PARAMETER_COUNTS.get(m, _DEFAULT_PARAMETER_COUNT)
        )
        model_agg['params_display'] = model_agg['params'].apply(_format_parameter_count)
        model_agg['model_short'] = model_agg['model'].apply(_short_model_name)

        # Scale bubbles against the largest model overall (not per family) so
        # sizes are comparable across families.
        max_params = model_agg['params'].max()
        size_scale = max_params if max_params > 0 else 1

        for family in model_agg['model_family'].unique():
            family_data = model_agg[model_agg['model_family'] == family]

            fig.add_trace(go.Scatter(
                x=family_data['wer'],
                y=family_data['rtfx'],
                mode='markers',
                name=family,
                marker=dict(
                    size=family_data['params'] / size_scale * 50 + 10,
                    sizemode='diameter',
                    sizemin=8,
                    color=_get_color_for_family(family),
                ),
                customdata=family_data[['model_short', 'cer', 'num_samples', 'params_display']].values,
                hovertemplate=(
                    '%{customdata[0]}<br><br>'
                    'WER: %{x:.3f}<br>'
                    'RTFx: %{y:.1f}<br>'
                    'CER: %{customdata[1]:.3f}<br>'
                    'Samples: %{customdata[2]:,}<br>'
                    'Parameters: %{customdata[3]}'
                ),
            ))

        title_text = "Speed vs Accuracy Tradeoff by Individual Model"
    else:
        # Family view: one uniformly sized bubble per model family
        model_agg = (
            df.groupby('model_family')
            .agg({'wer': 'median', 'rtfx': 'median', 'cer': 'median'})
            .reset_index()
            .merge(_sum_unique_samples(df, ['model_family'], 'num_samples'), on='model_family', how='left')
        )

        for family in model_agg['model_family'].unique():
            family_data = model_agg[model_agg['model_family'] == family]

            fig.add_trace(go.Scatter(
                x=family_data['wer'],
                y=family_data['rtfx'],
                mode='markers',
                name=family,
                marker=dict(
                    size=20,
                    sizemode='diameter',
                    color=_get_color_for_family(family),
                ),
                customdata=family_data[['cer', 'num_samples']].values,
                hovertemplate=(
                    family + '<br><br>'
                    'WER: %{x:.3f}<br>'
                    'RTFx: %{y:.1f}<br>'
                    'CER: %{customdata[0]:.3f}<br>'
                    'Samples: %{customdata[1]:,}'
                ),
            ))

        title_text = "Speed vs Accuracy Tradeoff by Model Family"

    # Quadrant overlay needs data: medians of an empty frame are NaN.
    if not model_agg.empty:
        median_wer = model_agg['wer'].median()
        median_rtfx = model_agg['rtfx'].median()

        fig.add_hline(y=median_rtfx, line_dash="dash", line_color="gray",
                      annotation_text="Median RTFx", annotation_position="right")
        fig.add_vline(x=median_wer, line_dash="dash", line_color="gray",
                      annotation_text="Median WER", annotation_position="top")

        # Label the ideal quadrant (low WER, high RTFx) at its center:
        # x from min WER to median WER, y from median RTFx to max RTFx.
        fig.add_annotation(
            x=(model_agg['wer'].min() + median_wer) / 2,
            y=(median_rtfx + model_agg['rtfx'].max()) / 2,
            text="Fast & Accurate ⭐",
            showarrow=False,
            font=dict(size=12, color="green", family="Arial Black"),
        )

    fig.update_layout(
        title=title_text,
        xaxis_title="WER",
        yaxis_title="RTFx",
        yaxis=dict(rangemode='tozero'),  # Ensure y-axis starts at 0 (RTFx can't be negative)
        height=550,
        autosize=True,
        template='plotly_white',
        legend=_BOTTOM_LEGEND,
        margin=_SCATTER_MARGIN,
    )

    return fig


def create_wer_cer_correlation(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int | None = None) -> go.Figure:
    """
    Visualization 7: WER vs CER Correlation

    Scatter plot showing the relationship between word and character error rates.
    Defaults to Swahili if no language is specified.
    High WER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        languages: List of languages to filter by (defaults to ["Swahili"] if None or empty)
        top_n_models: If specified, only show top N models by median WER (None or 0 = show all)
    """
    if not languages:
        languages = ["Swahili"]

    filtered_df = _remove_wer_outliers(df[df['language'].isin(languages)])

    if top_n_models and top_n_models > 0:
        top_models = (
            filtered_df.groupby('model')['wer']
            .median()
            .sort_values()
            .head(top_n_models)
            .index
        )
        filtered_df = filtered_df[filtered_df['model'].isin(top_models)]

    fig = go.Figure()

    # Bubble sizes are relative to the largest sample count on the chart;
    # fall back to 1 when there is no positive maximum (empty, zero or NaN).
    max_samples = filtered_df['num_samples'].max() if not filtered_df.empty else 1
    if not max_samples > 0:
        max_samples = 1

    for family in filtered_df['model_family'].unique():
        family_data = filtered_df[filtered_df['model_family'] == family]

        fig.add_trace(go.Scatter(
            x=family_data['wer'],
            y=family_data['cer'],
            mode='markers',
            name=family,
            marker=dict(
                size=(family_data['num_samples'] / max_samples * 25 + 5).values,
                sizemode='diameter',
                sizemin=5,
                opacity=0.6,
                color=_get_color_for_family(family),
            ),
            customdata=family_data[['language', 'model', 'dataset_group', 'num_samples']].values,
            hovertemplate=(
                '%{customdata[0]}<br><br>'
                'Model: %{customdata[1]}<br>'
                'Dataset: %{customdata[2]}<br>'
                'WER: %{x:.3f}<br>'
                'CER: %{y:.3f}<br>'
                'Samples: %{customdata[3]:,}'
            ),
        ))

    # Trendline and correlation use only rows where both metrics are present,
    # since NaN values make the least-squares fit fail.
    points = filtered_df[['wer', 'cer']].dropna()
    correlation = 0.0
    if len(points) > 1:
        # A line fit needs at least two distinct x values.
        if points['wer'].nunique() > 1:
            slope, intercept = np.polyfit(points['wer'], points['cer'], 1)
            x_range = np.linspace(points['wer'].min(), points['wer'].max(), 100)
            fig.add_trace(go.Scatter(
                x=x_range,
                y=slope * x_range + intercept,
                mode='lines',
                name='Trend',
                line=dict(color='gray', dash='dash'),
                hoverinfo='skip',
            ))

        pearson = points['wer'].corr(points['cer'])
        # Correlation is undefined (NaN) for constant data; report 0 instead.
        if pd.notna(pearson):
            correlation = pearson

    lang_str = _format_language_list(languages)
    title_text = f"WER vs CER Correlation for {lang_str} (r={correlation:.2f})"

    fig.update_layout(
        title=title_text,
        xaxis_title="WER",
        yaxis_title="CER",
        height=550,
        autosize=True,
        template='plotly_white',
        legend=_BOTTOM_LEGEND,
        margin=_SCATTER_MARGIN,
        # Zoom both axes to the useful range (0-1.5)
        xaxis=dict(range=[0, 1.5]),
        yaxis=dict(range=[0, 1.5]),
    )

    return fig


def create_model_consistency(df: pd.DataFrame) -> go.Figure:
    """
    Visualization 9: Model Consistency Analysis

    Ranks model families by the coefficient of variation (CV = std / median
    of WER, in percent); lower CV means more consistent results.
    Removes high outliers only using IQR method (keeps best performers).

    Args:
        df: DataFrame with evaluation results ('model_family' and 'wer' columns)
    """
    df_no_outliers = _remove_wer_outliers(df)

    model_variance = (
        df_no_outliers.groupby('model_family')
        .agg(
            wer_median=('wer', 'median'),
            wer_std=('wer', 'std'),
            count=('wer', 'count'),
        )
        .reset_index()
    )
    model_variance['cv'] = model_variance['wer_std'] / model_variance['wer_median'] * 100
    model_variance = model_variance.sort_values('cv')
    model_variance['color'] = model_variance['model_family'].apply(_get_color_for_family)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=model_variance['model_family'],
        x=model_variance['cv'],
        orientation='h',
        marker=dict(
            color=model_variance['color'],
        ),
        text=model_variance['cv'].round(1),
        textposition='outside',
        hovertemplate=(
            '%{y}<br><br>'
            'Coefficient of Variation: %{x:.1f}%<br>'
            'Median WER: %{customdata[0]:.3f}<br>'
            'Std Dev: %{customdata[1]:.3f}<br>'
            'Evaluations: %{customdata[2]}'
        ),
        customdata=model_variance[['wer_median', 'wer_std', 'count']],
    ))

    fig.update_layout(
        title="Model Consistency Ranking (Outliers Removed)",
        xaxis_title="Coefficient of Variation (%)",
        yaxis_title="Model Family",
        height=550,
        template='plotly_white',
        showlegend=False,
        margin=dict(l=200, r=100, t=80, b=80),
    )

    return fig