# paza-bench / src/data_processing.py
# (Hugging Face Space artifact header — author: muchai-mercy,
#  commit 53a73e0, message: "update pazabench space")
"""
Data loading, filtering, aggregation and metric computation for PazaBench.
This module handles all data processing logic:
- Loading ASR results from CSV (with HuggingFace Hub fallback)
- Filtering dataframes by model, language, dataset, region
- Aggregating results by language with proper RTFx computation
- Building metric pivot tables and DataFrames
"""
from functools import lru_cache
from pathlib import Path
import pandas as pd
from huggingface_hub import hf_hub_download
from src.constants import (
ASR_DISPLAY_COLUMNS,
ASR_NUMERIC_COLUMNS,
ASR_TEXT_COLUMNS,
DEFAULT_VIEW_MODE,
FILTER_COLUMN_ORDER,
LANGUAGE_NAME_MAPPING,
METRIC_CONFIGS,
RESULTS_CSV_FILENAME,
RESULTS_CSV_PATH,
VIEW_MODE_COLUMNS,
)
from src.display.styling import format_metric_value
from src.envs import HF_ENABLED, RESULTS_REPO, TOKEN
from src.language_metadata import get_language_regions
# =============================================================================
# Data Loading
# =============================================================================
@lru_cache(maxsize=1)
def _cached_asr_results(csv_path: str) -> pd.DataFrame:
    """Load ASR results from HuggingFace Hub or fall back to local file.

    The path is passed as a string so the lru_cache key is hashable; the
    single-entry cache means the CSV is read at most once per process run.

    Raises:
        FileNotFoundError: if neither a Hub download nor the local file exists.
    """
    path = Path(csv_path)
    # Prefer fresh results from the Hub whenever HF access is configured.
    if HF_ENABLED and RESULTS_REPO:
        try:
            print(f"Downloading {RESULTS_CSV_FILENAME} from {RESULTS_REPO}...")
            downloaded_path = hf_hub_download(
                repo_id=RESULTS_REPO,
                filename=RESULTS_CSV_FILENAME,
                repo_type="dataset",
                token=TOKEN,
            )
            path = Path(downloaded_path)
            print(f"Successfully downloaded results from {RESULTS_REPO}")
        except Exception as e:
            # Best-effort download: any failure falls back to the local copy.
            print(f"Could not download from HuggingFace Hub: {e}")
            print(f"Falling back to local file at {csv_path}")
            path = Path(csv_path)
    if not path.exists():
        raise FileNotFoundError(
            f"ASR results CSV not found at {path}. "
            "Please generate it via src.aggregate_results before launching the app."
        )
    frame = pd.read_csv(path)
    # Convert numeric columns; unparsable values become NaN.
    for column in ASR_NUMERIC_COLUMNS:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")
    # Fill missing text columns so filters get a concrete "Unknown" bucket.
    for column in ASR_TEXT_COLUMNS:
        if column in frame.columns:
            frame[column] = frame[column].fillna("Unknown")
    # Normalize language names to their canonical display form.
    if "language" in frame.columns:
        frame["language"] = frame["language"].replace(LANGUAGE_NAME_MAPPING)
    # Filter out rows with very few samples (rows with NaN counts drop too,
    # since NaN >= threshold is False).
    MIN_SAMPLES_THRESHOLD = 10
    if "num_samples" in frame.columns:
        frame = frame[frame["num_samples"] >= MIN_SAMPLES_THRESHOLD]
    # Add region metadata column: comma-joined region list per language.
    if "language" in frame.columns:
        frame["african_region"] = frame["language"].apply(
            lambda x: ", ".join(get_language_regions(x))
        )
    return frame
def load_asr_results(csv_path: Path = RESULTS_CSV_PATH) -> pd.DataFrame:
    """Return the (cached) ASR results DataFrame loaded from *csv_path*."""
    # Stringify so the lru_cache key on the helper stays hashable and stable.
    cache_key = str(csv_path)
    return _cached_asr_results(cache_key)
# =============================================================================
# Filtering
# =============================================================================
def _sorted_column_values(frame: pd.DataFrame, column: str) -> list[str]:
"""Get sorted unique values from a column, with 'Unknown' at the end."""
if column not in frame.columns or frame.empty:
return []
values = sorted({value for value in frame[column].dropna().unique() if value != "Unknown"})
if (frame[column] == "Unknown").any():
values.append("Unknown")
return values
def get_filter_options(frame: pd.DataFrame) -> dict[str, list[str]]:
    """Collect every selectable filter value exposed by the UI."""
    from src.about import get_dataset_group_label
    from src.language_metadata import get_all_regions
    options: dict[str, list[str]] = {}
    for column in FILTER_COLUMN_ORDER:
        options[column] = _sorted_column_values(frame, column)
    dataset_groups = options.get("dataset_group", [])
    options["dataset_group_labels"] = [get_dataset_group_label(dg) for dg in dataset_groups]
    # Regions come from static language metadata, not from the frame itself.
    options["african_region"] = get_all_regions()
    return options
def get_languages_for_filters(
    frame: pd.DataFrame,
    african_regions: list[str] | None = None,
) -> list[str]:
    """Return the frame's languages, optionally restricted to given regions."""
    from src.language_metadata import get_languages_by_region
    available = _sorted_column_values(frame, "language")
    if not african_regions:
        return available
    # Union of all languages spoken in any selected region.
    in_regions: set[str] = set()
    for region in african_regions:
        in_regions |= set(get_languages_by_region(region))
    return sorted(in_regions.intersection(available))
def filter_asr_dataframe(
    frame: pd.DataFrame,
    *,
    models: list[str] | None = None,
    languages: list[str] | None = None,
    dataset_groups: list[str] | None = None,
    african_regions: list[str] | None = None,
) -> pd.DataFrame:
    """Return a copy of *frame* restricted to the selected filter values."""
    from src.language_metadata import get_languages_by_region
    result = frame.copy()
    # Straightforward column-membership filters; empty selections are no-ops.
    column_filters = (
        ("model", models),
        ("language", languages),
        ("dataset_group", dataset_groups),
    )
    for column, selected in column_filters:
        if selected:
            result = result[result[column].isin(selected)]
    # A region filter is expressed as the union of each region's languages.
    if african_regions:
        permitted: set[str] = set()
        for region in african_regions:
            permitted.update(get_languages_by_region(region))
        if permitted and "language" in result.columns:
            result = result[result["language"].isin(permitted)]
    return result
def prepare_display_dataframe(
    frame: pd.DataFrame,
    max_rows: int,
    include_split: bool = True
) -> pd.DataFrame:
    """Select, round, truncate and rank the columns shown in the results table."""
    selected = [col for col in ASR_DISPLAY_COLUMNS if col in frame.columns]
    if not include_split:
        selected = [col for col in selected if col != "split"]
    display = frame.loc[:, selected].copy()
    # Error-rate/RTFx columns get three decimals; time columns get two.
    decimals = {"wer": 3, "cer": 3, "rtfx": 3, "duration_sec": 2, "inference_time_sec": 2}
    for column, digits in decimals.items():
        if column in display.columns:
            display[column] = display[column].round(digits)
    display = display.head(max_rows)
    # Leading unnamed column carries the 1-based row rank.
    display.insert(0, "", range(1, len(display) + 1))
    return display
# =============================================================================
# Aggregation
# =============================================================================
def _join_unique_values(values: pd.Series) -> str | None:
"""Join unique non-empty values from a Series."""
if values is None:
return None
unique_values = [str(v) for v in values.dropna().unique() if str(v).strip()]
return ", ".join(unique_values) if unique_values else None
def _weighted_average(series: pd.Series, weights: pd.Series) -> float | None:
"""Compute weighted average of a series."""
numeric = pd.to_numeric(series, errors="coerce").dropna()
if numeric.empty:
return None
aligned_weights = weights.loc[numeric.index].fillna(0).astype(float)
total_weight = aligned_weights.sum()
if total_weight <= 0:
return float(round(numeric.mean(), 4))
return float(round((numeric * aligned_weights).sum() / total_weight, 4))
def _compute_rtfx_from_totals(
duration_series: pd.Series,
inference_time_series: pd.Series
) -> float | None:
"""Compute RTFx as Total Audio Duration / Total Transcription Time."""
duration_numeric = pd.to_numeric(duration_series, errors="coerce").dropna()
inference_numeric = pd.to_numeric(inference_time_series, errors="coerce").dropna()
common_index = duration_numeric.index.intersection(inference_numeric.index)
if common_index.empty:
return None
total_duration = duration_numeric.loc[common_index].sum()
total_inference = inference_numeric.loc[common_index].sum()
if total_inference <= 0:
return None
return float(round(total_duration / total_inference, 4))
def aggregate_by_language(frame: pd.DataFrame) -> pd.DataFrame:
    """Aggregate results by model and language with proper metric computation.

    WER/CER are averaged weighted by num_samples, while RTFx is recomputed
    from SUMMED durations and inference times (not averaged), so long slow
    files are weighted correctly.

    Returns:
        One row per (model_family, model, language) group, or *frame*
        unchanged when it is empty or lacks the grouping columns.
    """
    if frame.empty or "language" not in frame.columns:
        return frame
    group_keys = [col for col in ["model_family", "model", "language"] if col in frame.columns]
    if not group_keys:
        return frame
    aggregated_rows: list[dict[str, object]] = []
    for _, group in frame.groupby(group_keys, dropna=False):
        # Sample counts weight the error-rate averages; missing counts become 0.
        weight_series = group["num_samples"] if "num_samples" in group.columns else pd.Series([1] * len(group))
        weight_series = pd.to_numeric(weight_series, errors="coerce").fillna(0).astype(float)
        # Compute RTFx: sum(duration) / sum(inference_time)
        rtfx_value = None
        if "duration_sec" in group.columns and "inference_time_sec" in group.columns:
            rtfx_value = _compute_rtfx_from_totals(group["duration_sec"], group["inference_time_sec"])
        # Sum durations and inference times; None when no numeric values exist.
        total_duration = None
        total_inference = None
        if "duration_sec" in group.columns:
            duration_numeric = pd.to_numeric(group["duration_sec"], errors="coerce").dropna()
            if not duration_numeric.empty:
                total_duration = float(round(duration_numeric.sum(), 4))
        if "inference_time_sec" in group.columns:
            inference_numeric = pd.to_numeric(group["inference_time_sec"], errors="coerce").dropna()
            if not inference_numeric.empty:
                total_inference = float(round(inference_numeric.sum(), 4))
        # NOTE: for a DataFrame, `"x" in group` tests column membership, so the
        # group.get(...) default Series below is never actually used.
        aggregated_rows.append({
            "model_family": group.get("model_family", pd.Series([None])).iloc[0] if "model_family" in group else None,
            "model": group.get("model", pd.Series([None])).iloc[0] if "model" in group else None,
            "dataset_group": _join_unique_values(group["dataset_group"]) if "dataset_group" in group.columns else None,
            "language": group.get("language", pd.Series([None])).iloc[0] if "language" in group else None,
            "region": _join_unique_values(group["region"]) if "region" in group.columns else None,
            "wer": _weighted_average(group["wer"], weight_series) if "wer" in group.columns else None,
            "cer": _weighted_average(group["cer"], weight_series) if "cer" in group.columns else None,
            "rtfx": rtfx_value,
            "duration_sec": total_duration,
            "inference_time_sec": total_inference,
            "num_samples": int(weight_series.sum()) if weight_series.sum() > 0 else len(group),
        })
    return pd.DataFrame(aggregated_rows)
# =============================================================================
# Metric Tables
# =============================================================================
def _pivot_metric_table(aggregated: pd.DataFrame, metric: str, view_mode: str) -> pd.DataFrame:
    """Pivot *aggregated* into language rows x model/family columns for one metric.

    Columns are ordered best-first according to the metric's direction
    (lower-is-better metrics sort ascending by column mean).
    """
    if aggregated.empty or metric not in aggregated.columns:
        return pd.DataFrame()
    column_key = VIEW_MODE_COLUMNS.get(view_mode, VIEW_MODE_COLUMNS[DEFAULT_VIEW_MODE])
    if column_key not in aggregated.columns:
        # Unknown view mode column: fall back to per-model columns if possible.
        if "model" not in aggregated.columns:
            return pd.DataFrame()
        column_key = "model"
    pivot = aggregated.pivot_table(
        index="language", columns=column_key, values=metric, aggfunc="median"
    )
    pivot = pivot.sort_index().dropna(how="all")
    if pivot.empty:
        return pivot
    # Rank columns by mean score in the metric's "better" direction.
    is_lower_better = METRIC_CONFIGS.get(metric, {}).get("better", "lower") == "lower"
    ordered = pivot.mean(skipna=True).sort_values(ascending=is_lower_better).index.tolist()
    # Safety net: re-append any column the ranking skipped.
    ordered += [col for col in pivot.columns if col not in ordered]
    return pivot[ordered]
def _build_metric_table_html(pivot: pd.DataFrame, metric: str) -> str:
    """Build an HTML table for a metric pivot table.

    Rows are languages (their header cells carry an inline onclick that
    toggles a highlight class); columns are models/families in the order of
    *pivot*, which the caller has already sorted by performance.
    """
    config = METRIC_CONFIGS[metric]
    if pivot.empty:
        return f"<p class='metric-table-empty'>No {config['label']} data available for the current filters.</p>"
    header_cells = "".join(f"<th>{col}</th>" for col in pivot.columns)
    rows_html: list[str] = []
    # enumerate from 1: the first cell of each row is its rank.
    for i, (language, row) in enumerate(pivot.iterrows(), 1):
        cell_html = []
        for column in pivot.columns:
            value = row[column]
            display = format_metric_value(value, config["fmt"])
            cell_html.append(f"<td>{display}</td>")
        # Inline JS: clicking a language toggles 'row-highlighted' on its row
        # after clearing the class from all sibling rows.
        onclick = (
            "(function(el){var tr=el.parentElement;"
            "var isHighlighted=tr.classList.contains('row-highlighted');"
            "tr.parentElement.querySelectorAll('tr').forEach(function(r){"
            "r.classList.remove('row-highlighted')});"
            "if(!isHighlighted){tr.classList.add('row-highlighted')}})(this)"
        )
        rows_html.append(
            f"<tr><td>{i}</td><th onclick=\"{onclick}\">{language}</th>{''.join(cell_html)}</tr>"
        )
    caption = f"<caption>Columns sorted by overall {config['label']} performance</caption>"
    table_html = (
        "<div class='metric-table-wrapper'>"
        f"<table class='metric-table'>{caption}<thead><tr><th>#</th><th>Language</th>{header_cells}</tr></thead>"
        f"<tbody>{''.join(rows_html)}</tbody></table></div>"
    )
    return table_html
def _build_metric_dataframe(pivot: pd.DataFrame, metric: str) -> pd.DataFrame:
"""Build a sortable DataFrame from the pivot table."""
if pivot.empty:
return pd.DataFrame()
df = pivot.reset_index()
df = df.rename(columns={"language": "Language"})
# Round numeric columns
for col in df.columns:
if col != "Language" and df[col].dtype in ['float64', 'float32']:
df[col] = df[col].round(2)
df.insert(0, "", range(1, len(df) + 1))
return df
def compute_metric_tables(
    models: list[str] | None,
    languages: list[str] | None,
    dataset_groups: list[str] | None,
    view_mode: str,
    african_regions: list[str] | None = None,
    asr_results_df: pd.DataFrame | None = None,
) -> dict[str, str]:
    """Render one HTML table per metric for the current filter selection."""
    if asr_results_df is None or asr_results_df.empty:
        placeholder = "<p class='metric-table-empty'>No ASR results loaded.</p>"
        return {metric: placeholder for metric in METRIC_CONFIGS}
    subset = filter_asr_dataframe(
        asr_results_df,
        models=models,
        languages=languages,
        dataset_groups=dataset_groups,
        african_regions=african_regions,
    )
    per_language = aggregate_by_language(subset)
    tables: dict[str, str] = {}
    for metric in METRIC_CONFIGS:
        metric_pivot = _pivot_metric_table(per_language, metric, view_mode)
        tables[metric] = _build_metric_table_html(metric_pivot, metric)
    return tables
def compute_metric_dataframes(
    models: list[str] | None,
    languages: list[str] | None,
    dataset_groups: list[str] | None,
    view_mode: str,
    african_regions: list[str] | None = None,
    asr_results_df: pd.DataFrame | None = None,
) -> dict[str, pd.DataFrame]:
    """Build one sortable DataFrame per metric for the current filter selection."""
    if asr_results_df is None or asr_results_df.empty:
        return {metric: pd.DataFrame() for metric in METRIC_CONFIGS}
    subset = filter_asr_dataframe(
        asr_results_df,
        models=models,
        languages=languages,
        dataset_groups=dataset_groups,
        african_regions=african_regions,
    )
    per_language = aggregate_by_language(subset)
    return {
        metric: _build_metric_dataframe(_pivot_metric_table(per_language, metric, view_mode), metric)
        for metric in METRIC_CONFIGS
    }
# =============================================================================
# Helper Functions
# =============================================================================
def strip_dataset_label(label: str) -> str:
    """Strip descriptor from dataset label.

    E.g., 'ALFFA (read speech & broadcast news, 3 languages)' -> 'ALFFA'
    """
    # str.partition yields the whole label unchanged when ' (' is absent.
    head, _, _ = label.partition(" (")
    return head
def strip_dataset_labels(labels: list[str] | None) -> list[str]:
    """Strip descriptors from a list of dataset labels; None/empty yields []."""
    return [strip_dataset_label(label) for label in (labels or [])]