"""
PazaBench Visualization Functions for Gradio Integration
"""
import os
import pandas as pd
import plotly.graph_objects as go
from src.constants import (
COUNTRY_NAMES,
LANGUAGE_COUNTRY_MAP,
LANGUAGE_SAMPLE_COUNTS,
)
from src.model_counts import MODEL_PARAMETER_COUNTS
# Load model family colors from CSV
def _load_model_family_colors() -> tuple[dict[str, str], dict[str, str]]:
"""
Load color mappings from the model_family_colors.csv file.
Returns:
- model_family_colors: dict mapping model_family -> color (first color for that family)
- model_id_colors: dict mapping model_id -> color
"""
csv_path = os.path.join(os.path.dirname(__file__), 'display', 'model_family_colors.csv')
model_family_colors = {}
model_id_colors = {}
try:
df = pd.read_csv(csv_path)
for _, row in df.iterrows():
family = row['Model Families']
model_id = row['Model ID']
color = row['Color']
# Store first color encountered for each family (using normalized key)
normalized_family = _normalize_family_name(family)
if normalized_family not in model_family_colors:
model_family_colors[normalized_family] = color
# Store color for each model ID
model_id_colors[model_id] = color
except Exception as e:
print(f"Warning: Could not load model family colors: {e}")
return model_family_colors, model_id_colors
def _normalize_family_name(family: str) -> str:
"""Normalize family name for consistent lookup (lowercase, remove underscores/dashes)."""
return family.lower().replace('_', '').replace('-', '').replace(' ', '')
# Module-level color lookups, populated once at import time from the CSV
# (empty dicts if the file is missing — see _load_model_family_colors).
MODEL_FAMILY_COLORS, MODEL_ID_COLORS = _load_model_family_colors()
def _get_color_for_family(family: str) -> str:
    """Return the color assigned to a model family, or neutral gray if unknown."""
    key = _normalize_family_name(family)
    if key in MODEL_FAMILY_COLORS:
        return MODEL_FAMILY_COLORS[key]
    return '#888888'
def _get_color_for_model(model_id: str) -> str:
    """
    Return the color for a specific model ID.

    Falls back to a family color when the model id itself is unknown, matched
    by checking whether a normalized family name appears inside the normalized
    model id; returns neutral gray if nothing matches.
    """
    if model_id in MODEL_ID_COLORS:
        return MODEL_ID_COLORS[model_id]
    # Try to find family and use family color.
    # MODEL_FAMILY_COLORS keys are normalized (lowercase, no underscores/dashes/spaces),
    # so the model id must be normalized the same way before the substring test —
    # comparing against the raw lowercased id would never match families whose
    # names contain separators (e.g. "open_ai whisper" vs "openai/whisper-large").
    normalized_id = _normalize_family_name(model_id)
    for family, color in MODEL_FAMILY_COLORS.items():
        if family in normalized_id:
            return color
    return '#888888'
def _remove_wer_outliers(df: pd.DataFrame, multiplier: float = 1.5) -> pd.DataFrame:
"""
Remove WER outliers using IQR method for cleaner visualizations.
Only removes HIGH outliers (poor performers), keeps LOW outliers (best performers).
Args:
df: DataFrame with 'wer' column
multiplier: IQR multiplier (default 1.5 for standard outlier detection)
Returns:
DataFrame with high outliers removed
"""
if df.empty or 'wer' not in df.columns:
return df
Q1 = df['wer'].quantile(0.25)
Q3 = df['wer'].quantile(0.75)
IQR = Q3 - Q1
# Only remove HIGH outliers (poor performers), keep LOW outliers (best performers)
# Lower WER is better, so we don't want to remove low-WER entries
upper_bound = Q3 + multiplier * IQR
return df[df['wer'] <= upper_bound]
def create_language_coverage_chart(selected_languages: list[str] | None = None) -> go.Figure:
    """
    Create a horizontal bar chart showing sample counts for each language in PazaBench.

    Languages are sorted by sample count (largest at the top).
    Selected languages are highlighted with a darker fill and a border.

    Args:
        selected_languages: List of languages to highlight (None = no highlighting)
    """
    # Create dataframe from language sample counts
    df = pd.DataFrame([
        {"Language": lang, "Sample Count": count}
        for lang, count in LANGUAGE_SAMPLE_COUNTS.items()
    ])
    # ascending=True so the largest bar lands at the top of a horizontal chart
    df = df.sort_values("Sample Count", ascending=True)

    # Determine colors - highlight selected languages
    highlight = bool(selected_languages)
    if highlight:
        selected = set(selected_languages)
        colors = [
            "#0078D4" if lang in selected else "#D0E8FF"
            for lang in df["Language"]
        ]
        # Add border to selected bars
        line_widths = [2 if lang in selected else 0 for lang in df["Language"]]
        line_colors = ["#005A9E" if lang in selected else "rgba(0,0,0,0)" for lang in df["Language"]]
    else:
        colors = "#8CD0FF"  # Solid blue color matching theme
        line_widths = 0
        line_colors = "rgba(0,0,0,0)"

    fig = go.Figure(go.Bar(
        y=df["Language"],
        x=df["Sample Count"],
        orientation='h',
        marker=dict(
            color=colors,
            line=dict(width=line_widths, color=line_colors) if highlight else None,
        ),
        text=None,  # No text labels above bars
        textposition='none',
        # Plotly hover templates use HTML line breaks, not "\n"
        hovertemplate='%{y}<br>Samples: %{x:,}',
    ))

    # Build title
    if highlight:
        title_text = f"Language Coverage in PazaBench ({len(selected_languages)} selected)"
    else:
        title_text = "Language Coverage in PazaBench"

    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=16),
            x=0.5
        ),
        xaxis_title="Number of Samples",
        yaxis_title="",
        height=800,
        autosize=True,
        margin=dict(l=120, r=50, t=60, b=40),
        template='plotly_white',
        showlegend=False,
    )
    return fig
def create_language_location_map(languages: str | list[str] | None = None) -> go.Figure:
    """
    Create an interactive choropleth map of Africa showing where specific language(s) exist.
    If no language is selected, shows a light overview map with all PazaBench countries highlighted.

    Args:
        languages: The language(s) to highlight on the map (single string or list)
    """
    fig = go.Figure()
    # Normalize to list so a single language name and a list are handled uniformly
    if isinstance(languages, str):
        languages = [languages]
    if languages and len(languages) > 0:
        # Get countries where these languages exist
        all_countries = set()
        country_language_map = {}  # country code -> selected languages spoken there
        for lang in languages:
            if lang in LANGUAGE_COUNTRY_MAP:
                for code in LANGUAGE_COUNTRY_MAP[lang]:
                    all_countries.add(code)
                    if code not in country_language_map:
                        country_language_map[code] = []
                    country_language_map[code].append(lang)
        if all_countries:
            # Create dataframe with countries that have these languages
            df_map = pd.DataFrame([
                {
                    "country_code": code,
                    "country_name": COUNTRY_NAMES.get(code, code),
                    "has_language": 1,
                    "languages": ", ".join(country_language_map.get(code, []))
                }
                for code in all_countries
            ])
            # Build hover text (Plotly hover templates use HTML <br> line breaks)
            if len(languages) == 1:
                hover_template = "%{text}<br>" + f"{languages[0]} is spoken here"
            else:
                hover_template = "%{text}<br>Languages: %{customdata}"
            fig.add_trace(go.Choropleth(
                locations=df_map["country_code"],
                z=df_map["has_language"],
                text=df_map["country_name"],
                customdata=df_map["languages"],
                hovertemplate=hover_template,
                colorscale=[[0, "#8CD0FF"], [1, "#8CD0FF"]],  # Solid blue color
                showscale=False,
                marker_line_color="white",
                marker_line_width=0.5,
            ))
            if len(languages) == 1:
                title_text = f"Where {languages[0]} is Spoken"
            else:
                title_text = f"Where {len(languages)} Selected Languages are Spoken"
        else:
            title_text = "Select a Language to Explore"
    else:
        # Default view: show all countries with PazaBench data lightly highlighted
        all_countries = set()
        for countries in LANGUAGE_COUNTRY_MAP.values():
            all_countries.update(countries)
        df_map = pd.DataFrame([
            {
                "country_code": code,
                "country_name": COUNTRY_NAMES.get(code, code),
                "in_pazabench": 1
            }
            for code in all_countries
        ])
        fig.add_trace(go.Choropleth(
            locations=df_map["country_code"],
            z=df_map["in_pazabench"],
            text=df_map["country_name"],
            hovertemplate="%{text}<br>Has PazaBench data",
            colorscale=[[0, "#E8F4FF"], [1, "#E8F4FF"]],  # Very light blue matching theme
            showscale=False,
            marker_line_color="#B3D9FF",
            marker_line_width=0.5,
        ))
        title_text = "Select a Language to Explore"
    fig.update_geos(
        visible=True,
        resolution=50,
        scope="africa",
        showcountries=True,
        countrycolor="lightgray",
        showcoastlines=True,
        coastlinecolor="gray",
        showland=True,
        landcolor="#f5f5f5",
        showocean=True,
        oceancolor="#e3f2fd",
        showlakes=True,
        lakecolor="#e3f2fd",
        projection_type="natural earth",
        center=dict(lat=5, lon=20),
    )
    fig.update_layout(
        title=dict(
            text=title_text,
            font=dict(size=18),
            x=0.5
        ),
        height=600,
        autosize=True,
        margin=dict(l=5, r=5, t=50, b=5),
        geo=dict(
            bgcolor="rgba(0,0,0,0)",
        )
    )
    # Use SVG renderer for better resolution
    fig.update_layout(
        template="plotly_white",
    )
    return fig
def get_language_sample_info(languages: str | list[str] | None = None, asr_df: pd.DataFrame | None = None) -> str:
    """
    Get the sample count information for specific language(s) as styled HTML.
    Returns HTML for display in Gradio.

    Args:
        languages: The language(s) to get information for (single string or list)
        asr_df: DataFrame with ASR results to extract dataset groups
    """
    # Normalize to list
    if isinstance(languages, str):
        languages = [languages]
    if languages and len(languages) > 0:
        # Filter to valid languages
        valid_languages = [lang for lang in languages if lang in LANGUAGE_SAMPLE_COUNTS]
        if valid_languages:
            # Aggregate data across all selected languages
            total_samples = sum(LANGUAGE_SAMPLE_COUNTS.get(lang, 0) for lang in valid_languages)
            all_countries = set()
            for lang in valid_languages:
                all_countries.update(LANGUAGE_COUNTRY_MAP.get(lang, []))
            country_names = sorted([COUNTRY_NAMES.get(code, code) for code in all_countries])
            # Get dataset groups from ASR data if available
            dataset_groups = set()
            if asr_df is not None and not asr_df.empty and 'language' in asr_df.columns:
                for lang in valid_languages:
                    lang_data = asr_df[asr_df['language'] == lang]
                    if not lang_data.empty and 'dataset_group' in lang_data.columns:
                        dataset_groups.update(lang_data['dataset_group'].unique().tolist())
            dataset_groups = sorted(dataset_groups)
            # Build title based on number of languages
            if len(valid_languages) == 1:
                title = f"🌍 {valid_languages[0]}"
            else:
                title = f"🌍 {', '.join(sorted(valid_languages))}"
            # NOTE(review): the HTML markup below was reconstructed — the original
            # styling tags were lost; confirm against the app's intended card layout.
            html = f"""
            <div style="padding: 12px;">
                <h3 style="margin-top: 0;">{title}</h3>
                <p><b>📊 Total Samples</b><br>{total_samples:,}</p>
                <p><b>📍 Countries ({len(country_names)})</b><br>{', '.join(country_names) if country_names else 'N/A'}</p>
                <p><b>📁 Dataset Sources ({len(dataset_groups)})</b><br>{(', '.join(dataset_groups) if dataset_groups else 'No dataset info available')}</p>
            </div>
            """
            return html
    # Default view - show sample overview summary
    total_languages = len(LANGUAGE_SAMPLE_COUNTS)
    total_samples = sum(LANGUAGE_SAMPLE_COUNTS.values())
    total_countries = len(set(code for codes in LANGUAGE_COUNTRY_MAP.values() for code in codes))
    html = f"""
    <div style="padding: 12px;">
        <h3 style="margin-top: 0;">📊 Sample Overview</h3>
        <p><b>🌍 Languages</b><br>{total_languages}</p>
        <p><b>📊 Total Samples</b><br>{total_samples:,}</p>
        <p><b>📍 Countries</b><br>{total_countries}</p>
        <p>👈 Select a language on the left to explore its details</p>
    </div>
    """
    return html
def get_language_sample_info_df(language: str | None = None) -> pd.DataFrame:
    """
    Legacy helper kept for backward compatibility; returns info as a DataFrame.
    """
    metrics = ["Language", "Total Samples", "Countries"]
    if language and language in LANGUAGE_SAMPLE_COUNTS:
        codes = LANGUAGE_COUNTRY_MAP.get(language, [])
        names = [COUNTRY_NAMES.get(code, code) for code in codes]
        values = [
            language,
            f"{LANGUAGE_SAMPLE_COUNTS[language]:,}",
            ", ".join(names) if names else "N/A",
        ]
    else:
        # Placeholder row shown before any language is chosen
        values = ["Select a language", "-", "-"]
    return pd.DataFrame({"Metric": metrics, "Value": values})
def get_all_languages() -> list[str]:
    """Return every PazaBench language name, alphabetically sorted."""
    return sorted(LANGUAGE_SAMPLE_COUNTS)
def create_africa_language_map() -> go.Figure:
    """
    Create an interactive choropleth map of Africa showing language coverage.
    Hover over countries to see the languages spoken there.
    """
    # Build country data with languages (invert LANGUAGE_COUNTRY_MAP)
    country_data = {}
    for language, countries in LANGUAGE_COUNTRY_MAP.items():
        for country_code in countries:
            if country_code not in country_data:
                country_data[country_code] = {
                    "languages": [],
                    "count": 0,
                    "country_name": COUNTRY_NAMES.get(country_code, country_code)
                }
            country_data[country_code]["languages"].append(language)
            country_data[country_code]["count"] += 1
    # Create dataframe for plotly
    df_map = pd.DataFrame([
        {
            "country_code": code,
            "country_name": data["country_name"],
            "language_count": data["count"],
            "languages": ", ".join(sorted(data["languages"]))
        }
        for code, data in country_data.items()
    ])
    fig = go.Figure(go.Choropleth(
        locations=df_map["country_code"],
        z=df_map["language_count"],
        text=df_map["country_name"],
        customdata=df_map[["languages", "language_count"]],
        # Plotly hover templates use HTML <br> line breaks;
        # customdata[0] = language list, customdata[1] = language count
        hovertemplate="%{text}<br>" +
                      "Languages: %{customdata[1]}<br>" +
                      "%{customdata[0]}",
        colorscale=[
            [0, "#DDF1FF"],
            [0.25, "#B5E0FF"],
            [0.5, "#8CD0FF"],
            [0.75, "#6BC0F5"],
            [1, "#4AAFEB"]
        ],
        showscale=True,
        colorbar=dict(
            title="Languages",
            tickmode="linear",
            tick0=1,
            dtick=1
        ),
        marker_line_color="white",
        marker_line_width=0.5,
    ))
    fig.update_geos(
        visible=True,
        resolution=50,
        scope="africa",
        showcountries=True,
        countrycolor="lightgray",
        showcoastlines=True,
        coastlinecolor="gray",
        showland=True,
        landcolor="#f5f5f5",
        showocean=True,
        oceancolor="#e3f2fd",
        showlakes=True,
        lakecolor="#e3f2fd",
        projection_type="natural earth",
        center=dict(lat=5, lon=20),
    )
    fig.update_layout(
        title=dict(
            text="African Languages in PazaBench",
            font=dict(size=18),
            x=0.5
        ),
        height=600,
        autosize=True,
        margin=dict(l=5, r=5, t=50, b=5),
        geo=dict(
            bgcolor="rgba(0,0,0,0)",
        ),
        template="plotly_white",
    )
    return fig
def create_model_leaderboard(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int = 15) -> go.Figure:
    """
    Visualization 1: Model Family / Individual Model Performance Leaderboard

    - When no language filter: Shows model families (aggregated)
    - When language(s) selected: Shows top N individual models
    High WER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        languages: List of languages to filter by (None = all languages)
        top_n_models: Number of top individual models to show when languages are filtered (default: 15)
    """
    # Apply language filter if provided
    filtered_df = df.copy()
    if languages:
        filtered_df = filtered_df[filtered_df['language'].isin(languages)]
    # Remove WER outliers for cleaner visualization
    filtered_df = _remove_wer_outliers(filtered_df)
    # Determine mode: individual models if languages selected, otherwise model families
    show_individual_models = languages is not None and len(languages) > 0
    if show_individual_models:
        # Individual model mode: show top N models by median WER
        model_perf = filtered_df.groupby(['model_family', 'model']).agg({
            'wer': ['median', 'std', 'count'],
            'cer': 'median',
            'rtfx': 'median'
        }).reset_index()
        # Unique sample counts per model: take one num_samples per
        # (model, language, dataset_group) so repeated rows don't double-count
        unique_samples = filtered_df.groupby(['model_family', 'model', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        model_samples = unique_samples.groupby(['model_family', 'model'])['num_samples'].sum().reset_index()
        model_samples.columns = ['model_family', 'model', 'total_samples']
        model_perf.columns = ['model_family', 'model', 'wer_median', 'wer_std', 'count', 'cer_median', 'rtfx_median']
        model_perf = model_perf.merge(model_samples, on=['model_family', 'model'], how='left')
        model_perf = model_perf.sort_values('wer_median').head(top_n_models)
        # Create short model name for display (strip hub namespace)
        model_perf['model_short'] = model_perf['model'].apply(
            lambda x: x.split('/')[-1] if '/' in x else x
        )
        # Get colors for each model based on family
        model_perf['color'] = model_perf['model_family'].apply(_get_color_for_family)
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=model_perf['model_short'],
            x=model_perf['wer_median'],
            orientation='h',
            marker=dict(color=model_perf['color']),
            text=model_perf['wer_median'].round(2),
            textposition='outside',
            # Plotly hover templates require HTML <br> line breaks
            hovertemplate=(
                '%{customdata[0]}<br>' +
                'Family: %{customdata[1]}<br>' +
                'Median WER: %{x:.3f}<br>' +
                'RTFx: %{customdata[2]:.1f}<br>' +
                'Evaluations: %{customdata[3]}<br>' +
                'Samples: %{customdata[4]:,}'
            ),
            customdata=model_perf[['model', 'model_family', 'rtfx_median', 'count', 'total_samples']]
        ))
        # Build title with language info (languages is guaranteed to be non-empty here)
        lang_list = languages if languages else []
        lang_str = ", ".join(lang_list[:3]) + ("..." if len(lang_list) > 3 else "")
        title_text = f"Top {min(top_n_models, len(model_perf))} Models for {lang_str}"
    else:
        # Model family mode (original behavior)
        model_perf = filtered_df.groupby('model_family').agg({
            'wer': ['median', 'std', 'count'],
            'cer': 'median',
            'rtfx': 'median'
        }).reset_index()
        # Get unique sample counts per model family (avoid double-counting across models)
        unique_samples = filtered_df.groupby(['model_family', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        family_samples = unique_samples.groupby('model_family')['num_samples'].sum().reset_index()
        family_samples.columns = ['model_family', 'total_samples']
        model_perf.columns = ['model_family', 'wer_median', 'wer_std', 'count', 'cer_median', 'rtfx_median']
        model_perf = model_perf.merge(family_samples, on='model_family', how='left')
        model_perf = model_perf.sort_values('wer_median')
        # Get colors for each model family
        model_perf['color'] = model_perf['model_family'].apply(_get_color_for_family)
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=model_perf['model_family'],
            x=model_perf['wer_median'],
            orientation='h',
            error_x=dict(type='data', array=model_perf['wer_std']),
            marker=dict(color=model_perf['color']),
            text=model_perf['wer_median'].round(2),
            textposition='outside',
            hovertemplate=(
                '%{y}<br>' +
                'Median WER: %{x:.3f}<br>' +
                'Std Dev: %{customdata[0]:.3f}<br>' +
                'RTFx: %{customdata[1]:.1f}<br>' +
                'Evaluations: %{customdata[2]}<br>' +
                'Samples: %{customdata[3]:,}'
            ),
            customdata=model_perf[['wer_std', 'rtfx_median', 'count', 'total_samples']]
        ))
        title_text = "Model Family Performance Leaderboard"
    # Calculate dynamic height based on number of items
    num_items = len(model_perf)
    height = max(400, min(700, 100 + num_items * 35))
    fig.update_layout(
        title=title_text,
        xaxis_title="Word Error Rate (WER)",
        yaxis_title="",
        height=height,
        autosize=True,
        showlegend=False,
        template='plotly_white',
        margin=dict(l=200, r=30, t=60, b=60)
    )
    return fig
def create_cer_leaderboard(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int = 15) -> go.Figure:
    """
    Visualization: CER Model Family / Individual Model Performance Leaderboard

    - When no language filter: Shows model families (aggregated)
    - When language(s) selected: Shows top N individual models
    High CER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        languages: List of languages to filter by (None = all languages)
        top_n_models: Number of top individual models to show when languages are filtered (default: 15)
    """
    # Apply language filter if provided
    filtered_df = df.copy()
    if languages:
        filtered_df = filtered_df[filtered_df['language'].isin(languages)]
    # Remove high CER outliers via the IQR rule (mirrors _remove_wer_outliers,
    # which is hard-coded to the 'wer' column)
    if not filtered_df.empty and 'cer' in filtered_df.columns:
        Q1 = filtered_df['cer'].quantile(0.25)
        Q3 = filtered_df['cer'].quantile(0.75)
        IQR = Q3 - Q1
        upper_bound = Q3 + 1.5 * IQR
        filtered_df = filtered_df[filtered_df['cer'] <= upper_bound]
    # Determine mode: individual models if languages selected, otherwise model families
    show_individual_models = languages is not None and len(languages) > 0
    if show_individual_models:
        # Individual model mode: show top N models by median CER
        model_perf = filtered_df.groupby(['model_family', 'model']).agg({
            'cer': ['median', 'std', 'count'],
            'wer': 'median',
            'rtfx': 'median'
        }).reset_index()
        # Unique sample counts per model (one num_samples per evaluation cell)
        unique_samples = filtered_df.groupby(['model_family', 'model', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        model_samples = unique_samples.groupby(['model_family', 'model'])['num_samples'].sum().reset_index()
        model_samples.columns = ['model_family', 'model', 'total_samples']
        model_perf.columns = ['model_family', 'model', 'cer_median', 'cer_std', 'count', 'wer_median', 'rtfx_median']
        model_perf = model_perf.merge(model_samples, on=['model_family', 'model'], how='left')
        model_perf = model_perf.sort_values('cer_median').head(top_n_models)
        # Create short model name for display
        model_perf['model_short'] = model_perf['model'].apply(
            lambda x: x.split('/')[-1] if '/' in x else x
        )
        # Get colors for each model based on family
        model_perf['color'] = model_perf['model_family'].apply(_get_color_for_family)
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=model_perf['model_short'],
            x=model_perf['cer_median'],
            orientation='h',
            marker=dict(color=model_perf['color']),
            text=model_perf['cer_median'].round(2),
            textposition='outside',
            # Plotly hover templates require HTML <br> line breaks
            hovertemplate=(
                '%{customdata[0]}<br>' +
                'Family: %{customdata[1]}<br>' +
                'Median CER: %{x:.3f}<br>' +
                'WER: %{customdata[2]:.3f}<br>' +
                'RTFx: %{customdata[3]:.1f}<br>' +
                'Evaluations: %{customdata[4]}<br>' +
                'Samples: %{customdata[5]:,}'
            ),
            customdata=model_perf[['model', 'model_family', 'wer_median', 'rtfx_median', 'count', 'total_samples']]
        ))
        # Build title with language info (languages is guaranteed to be non-empty here)
        lang_list = languages if languages else []
        lang_str = ", ".join(lang_list[:3]) + ("..." if len(lang_list) > 3 else "")
        title_text = f"Top {min(top_n_models, len(model_perf))} Models by CER for {lang_str}"
    else:
        # Model family mode (original behavior)
        model_perf = filtered_df.groupby('model_family').agg({
            'cer': ['median', 'std', 'count'],
            'wer': 'median',
            'rtfx': 'median'
        }).reset_index()
        # Get unique sample counts per model family (avoid double-counting across models)
        unique_samples = filtered_df.groupby(['model_family', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        family_samples = unique_samples.groupby('model_family')['num_samples'].sum().reset_index()
        family_samples.columns = ['model_family', 'total_samples']
        model_perf.columns = ['model_family', 'cer_median', 'cer_std', 'count', 'wer_median', 'rtfx_median']
        model_perf = model_perf.merge(family_samples, on='model_family', how='left')
        model_perf = model_perf.sort_values('cer_median')
        # Get colors for each model family
        model_perf['color'] = model_perf['model_family'].apply(_get_color_for_family)
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=model_perf['model_family'],
            x=model_perf['cer_median'],
            orientation='h',
            error_x=dict(type='data', array=model_perf['cer_std']),
            marker=dict(color=model_perf['color']),
            text=model_perf['cer_median'].round(2),
            textposition='outside',
            hovertemplate=(
                '%{y}<br>' +
                'Median CER: %{x:.3f}<br>' +
                'Std Dev: %{customdata[0]:.3f}<br>' +
                'WER: %{customdata[1]:.3f}<br>' +
                'RTFx: %{customdata[2]:.1f}<br>' +
                'Evaluations: %{customdata[3]}<br>' +
                'Samples: %{customdata[4]:,}'
            ),
            customdata=model_perf[['cer_std', 'wer_median', 'rtfx_median', 'count', 'total_samples']]
        ))
        title_text = "Model Family Performance by CER"
    # Calculate dynamic height based on number of items
    num_items = len(model_perf)
    height = max(400, min(700, 100 + num_items * 35))
    fig.update_layout(
        title=title_text,
        xaxis_title="Character Error Rate (CER)",
        yaxis_title="",
        height=height,
        autosize=True,
        showlegend=False,
        template='plotly_white',
        margin=dict(l=200, r=30, t=60, b=60)
    )
    return fig
def create_speed_accuracy_scatter(df: pd.DataFrame, view_mode: str = "model_family", languages: list[str] | None = None) -> go.Figure:
    """
    Visualization 2: Speed vs Accuracy Tradeoff

    Scatter plot showing the relationship between WER and RTFx with quadrant
    guides at the medians. High WER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        view_mode: Either "model_family" (bubbles same size per family, color by family) or
                   "individual_model" (bubble size = model params, color = model family)
        languages: List of languages to filter by (None = all languages)
    """
    # Apply language filter if provided
    if languages:
        df = df[df['language'].isin(languages)]
    # Remove WER outliers for cleaner visualization
    df = _remove_wer_outliers(df)
    if view_mode == "individual_model":
        # Aggregate by individual model
        model_agg = df.groupby(['model_family', 'model']).agg({
            'wer': 'median',
            'rtfx': 'median',
            'cer': 'median',
        }).reset_index()
        # Get unique sample counts per model
        unique_samples = df.groupby(['model_family', 'model', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        model_samples = unique_samples.groupby(['model_family', 'model'])['num_samples'].sum().reset_index()
        model_samples.columns = ['model_family', 'model', 'num_samples']
        model_agg = model_agg.merge(model_samples, on=['model_family', 'model'], how='left')
        # Get parameter count for each individual model
        model_agg['params'] = model_agg['model'].apply(
            lambda m: MODEL_PARAMETER_COUNTS.get(m, 500_000_000)  # Default 500M
        )
        model_agg['params_billions'] = model_agg['params'] / 1_000_000_000
        model_agg['params_display'] = model_agg['params'].apply(
            lambda x: f"{x/1_000_000_000:.1f}B" if x >= 1_000_000_000 else f"{x/1_000_000:.0f}M"
        )
        # Create short model name for display
        model_agg['model_short'] = model_agg['model'].apply(
            lambda x: x.split('/')[-1] if '/' in x else x
        )
        # Normalize bubble sizes against the GLOBAL max so sizes are comparable
        # across families (per-family normalization would give every family's
        # largest model the same bubble regardless of absolute parameter count)
        global_max_params = model_agg['params'].max()
        fig = go.Figure()
        # Add scatter traces for each model family
        for family in model_agg['model_family'].unique():
            family_data = model_agg[model_agg['model_family'] == family]
            family_color = _get_color_for_family(family)
            fig.add_trace(go.Scatter(
                x=family_data['wer'],
                y=family_data['rtfx'],
                mode='markers',
                name=family,
                marker=dict(
                    size=family_data['params'] / global_max_params * 50 + 10,
                    sizemode='diameter',
                    sizemin=8,
                    color=family_color
                ),
                customdata=family_data[['model_short', 'cer', 'num_samples', 'params_display']].values,
                # Plotly hover templates require HTML <br> line breaks
                hovertemplate=(
                    '%{customdata[0]}<br>' +
                    'WER: %{x:.3f}<br>' +
                    'RTFx: %{y:.1f}<br>' +
                    'CER: %{customdata[1]:.3f}<br>' +
                    'Samples: %{customdata[2]:,}<br>' +
                    'Parameters: %{customdata[3]}'
                )
            ))
        title_text = "Speed vs Accuracy Tradeoff by Individual Model"
    else:
        # Original behavior: aggregate by model family
        model_agg = df.groupby('model_family').agg({
            'wer': 'median',
            'rtfx': 'median',
            'cer': 'median',
            'model': 'first'  # Get a representative model name for parameter lookup
        }).reset_index()
        # Get unique sample counts per model family
        unique_samples = df.groupby(['model_family', 'language', 'dataset_group'])['num_samples'].first().reset_index()
        family_samples = unique_samples.groupby('model_family')['num_samples'].sum().reset_index()
        family_samples.columns = ['model_family', 'num_samples']
        model_agg = model_agg.merge(family_samples, on='model_family', how='left')
        # For model family view, use a uniform size (no bubble size variation)
        model_agg['params'] = 1_000_000_000  # Constant 1B for uniform bubble size
        model_agg['params_display'] = 'N/A'  # Not applicable in family view
        fig = go.Figure()
        # Add scatter traces for each model family
        for family in model_agg['model_family'].unique():
            family_data = model_agg[model_agg['model_family'] == family]
            family_color = _get_color_for_family(family)
            fig.add_trace(go.Scatter(
                x=family_data['wer'],
                y=family_data['rtfx'],
                mode='markers',
                name=family,
                marker=dict(
                    size=20,
                    sizemode='diameter',
                    color=family_color
                ),
                customdata=family_data[['cer', 'num_samples']].values,
                hovertemplate=(
                    '<b>' + family + '</b><br>' +
                    'WER: %{x:.3f}<br>' +
                    'RTFx: %{y:.1f}<br>' +
                    'CER: %{customdata[0]:.3f}<br>' +
                    'Samples: %{customdata[1]:,}'
                )
            ))
        title_text = "Speed vs Accuracy Tradeoff by Model Family"
    # Add quadrant lines at median
    median_wer = model_agg['wer'].median()
    median_rtfx = model_agg['rtfx'].median()
    fig.add_hline(y=median_rtfx, line_dash="dash", line_color="gray", annotation_text="Median RTFx", annotation_position="right")
    fig.add_vline(x=median_wer, line_dash="dash", line_color="gray", annotation_text="Median WER", annotation_position="top")
    # Add quadrant label - centered in the "Fast & Accurate" quadrant (low WER, high RTFx)
    quadrant_center_x = (model_agg['wer'].min() + median_wer) / 2
    quadrant_center_y = (median_rtfx + model_agg['rtfx'].max()) / 2
    fig.add_annotation(
        x=quadrant_center_x,
        y=quadrant_center_y,
        text="Fast & Accurate ⭐",
        showarrow=False,
        font=dict(size=12, color="green", family="Arial Black")
    )
    fig.update_layout(
        title=title_text,
        xaxis_title="WER",
        yaxis_title="RTFx",
        yaxis=dict(rangemode='tozero'),  # Ensure y-axis starts at 0 (RTFx can't be negative)
        height=550,
        autosize=True,
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(size=10)
        ),
        margin=dict(l=60, r=20, t=50, b=120)
    )
    return fig
def create_wer_cer_correlation(df: pd.DataFrame, languages: list[str] | None = None, top_n_models: int | None = None) -> go.Figure:
    """
    Visualization 7: WER vs CER Correlation

    Scatter plot showing the relationship between word and character error rates,
    with a linear trendline and the Pearson correlation in the title.
    Defaults to Swahili if no language is specified.
    High WER outliers are removed for cleaner visualization.

    Args:
        df: DataFrame with evaluation results
        languages: List of languages to filter by (defaults to ["Swahili"] if None)
        top_n_models: If specified, only show top N models by WER (0 = show all)
    """
    # Default to Swahili if no language filter provided
    if not languages:
        languages = ["Swahili"]
    # Apply language filter
    filtered_df = df.copy()
    filtered_df = filtered_df[filtered_df['language'].isin(languages)]
    # Remove WER outliers for cleaner visualization
    filtered_df = _remove_wer_outliers(filtered_df)
    # Apply top N models filter if specified
    if top_n_models and top_n_models > 0:
        model_wer = filtered_df.groupby('model')['wer'].median().sort_values()
        top_models = model_wer.head(top_n_models).index.tolist()
        filtered_df = filtered_df[filtered_df['model'].isin(top_models)]
    fig = go.Figure()
    # Add scatter traces for each model family
    for family in filtered_df['model_family'].unique():
        family_data = filtered_df[filtered_df['model_family'] == family]
        family_color = _get_color_for_family(family)
        # Scale marker size by sample count relative to the largest evaluation
        max_samples = filtered_df['num_samples'].max() if not filtered_df.empty else 1
        sizes = (family_data['num_samples'] / max_samples * 25 + 5).values if not family_data.empty else [10]
        fig.add_trace(go.Scatter(
            x=family_data['wer'],
            y=family_data['cer'],
            mode='markers',
            name=family,
            marker=dict(
                size=sizes,
                sizemode='diameter',
                sizemin=5,
                opacity=0.6,
                color=family_color
            ),
            customdata=family_data[['language', 'model', 'dataset_group', 'num_samples']].values,
            # Plotly hover templates require HTML <br> line breaks
            hovertemplate=(
                '%{customdata[0]}<br>' +
                'Model: %{customdata[1]}<br>' +
                'Dataset: %{customdata[2]}<br>' +
                'WER: %{x:.3f}<br>' +
                'CER: %{y:.3f}<br>' +
                'Samples: %{customdata[3]:,}'
            )
        ))
    # Add least-squares trendline (needs at least two points)
    if len(filtered_df) > 1:
        import numpy as np
        z = np.polyfit(filtered_df['wer'], filtered_df['cer'], 1)
        p = np.poly1d(z)
        x_range = np.linspace(filtered_df['wer'].min(), filtered_df['wer'].max(), 100)
        fig.add_trace(go.Scatter(
            x=x_range,
            y=p(x_range),
            mode='lines',
            name='Trend',
            line=dict(color='gray', dash='dash'),
            hoverinfo='skip'
        ))
    # Calculate Pearson correlation for the title
    correlation = filtered_df[['wer', 'cer']].corr().iloc[0, 1] if len(filtered_df) > 1 else 0
    # Build title with language info
    lang_str = ", ".join(languages[:3]) + ("..." if len(languages) > 3 else "")
    title_text = f"WER vs CER Correlation for {lang_str} (r={correlation:.2f})"
    fig.update_layout(
        title=title_text,
        xaxis_title="WER",
        yaxis_title="CER",
        height=550,
        autosize=True,
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(size=10)
        ),
        margin=dict(l=60, r=20, t=50, b=120),
        # Zoom both axes to the useful range (0-1.5)
        xaxis=dict(range=[0, 1.5]),
        yaxis=dict(range=[0, 1.5])
    )
    return fig
def create_model_consistency(df: pd.DataFrame) -> go.Figure:
    """
    Visualization 9: Model Consistency Analysis

    Shows coefficient of variation (CV = std / median * 100) of WER per model
    family to measure consistency across languages; lower CV = more consistent.
    Removes high outliers only using the IQR method (keeps best performers).
    """
    # Reuse the shared helper (same 1.5*IQR upper fence as the inline version,
    # plus a guard for empty / wer-less frames)
    df_no_outliers = _remove_wer_outliers(df)
    model_variance = df_no_outliers.groupby('model_family').agg({
        'wer': ['median', 'std', 'count']
    }).reset_index()
    model_variance.columns = ['model_family', 'wer_median', 'wer_std', 'count']
    model_variance['cv'] = (model_variance['wer_std'] / model_variance['wer_median'] * 100)
    model_variance = model_variance.sort_values('cv')
    # Get colors for each model family
    model_variance['color'] = model_variance['model_family'].apply(_get_color_for_family)
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=model_variance['model_family'],
        x=model_variance['cv'],
        orientation='h',
        marker=dict(
            color=model_variance['color']
        ),
        text=model_variance['cv'].round(1),
        textposition='outside',
        # Plotly hover templates require HTML <br> line breaks
        hovertemplate=(
            '%{y}<br>' +
            'Coefficient of Variation: %{x:.1f}%<br>' +
            'Median WER: %{customdata[0]:.3f}<br>' +
            'Std Dev: %{customdata[1]:.3f}<br>' +
            'Evaluations: %{customdata[2]}'
        ),
        customdata=model_variance[['wer_median', 'wer_std', 'count']]
    ))
    fig.update_layout(
        title="Model Consistency Ranking (Outliers Removed)",
        xaxis_title="Coefficient of Variation (%)",
        yaxis_title="Model Family",
        height=550,
        template='plotly_white',
        showlegend=False,
        margin=dict(l=200, r=100, t=80, b=80)
    )
    return fig