"""
Data loading, filtering, aggregation and metric computation for PazaBench.

This module handles all data processing logic:
- Loading ASR results from CSV (with HuggingFace Hub fallback)
- Filtering dataframes by model, language, dataset, region
- Aggregating results by language with proper RTFx computation
- Building metric pivot tables and DataFrames
"""

from functools import lru_cache
from pathlib import Path

import pandas as pd
from huggingface_hub import hf_hub_download

from src.constants import (
    ASR_DISPLAY_COLUMNS,
    ASR_NUMERIC_COLUMNS,
    ASR_TEXT_COLUMNS,
    DEFAULT_VIEW_MODE,
    FILTER_COLUMN_ORDER,
    LANGUAGE_NAME_MAPPING,
    METRIC_CONFIGS,
    RESULTS_CSV_FILENAME,
    RESULTS_CSV_PATH,
    VIEW_MODE_COLUMNS,
)
from src.display.styling import format_metric_value
from src.envs import HF_ENABLED, RESULTS_REPO, TOKEN
from src.language_metadata import get_language_regions


# =============================================================================
# Data Loading
# =============================================================================

# Rows whose ``num_samples`` is below this value are dropped at load time.
MIN_SAMPLES_THRESHOLD = 10


@lru_cache(maxsize=1)
def _cached_asr_results(csv_path: str) -> pd.DataFrame:
    """Load ASR results from HuggingFace Hub or fall back to the local file.

    When ``HF_ENABLED`` and ``RESULTS_REPO`` are both truthy the CSV is
    downloaded from the Hub; any download error is printed and the local
    ``csv_path`` is used instead. The loaded frame is then normalised:
    numeric columns are coerced, missing text values become ``"Unknown"``,
    language names are mapped through ``LANGUAGE_NAME_MAPPING``, rows with
    fewer than ``MIN_SAMPLES_THRESHOLD`` samples are dropped and an
    ``african_region`` column is derived from ``language``.

    The result is memoised (``lru_cache(maxsize=1)``), so repeated calls with
    the same path return the same DataFrame object.

    Args:
        csv_path: Path of the local results CSV, as a string so it is hashable.

    Returns:
        The cleaned results DataFrame.

    Raises:
        FileNotFoundError: If the resolved CSV path does not exist.
    """
    path = Path(csv_path)

    if HF_ENABLED and RESULTS_REPO:
        try:
            print(f"Downloading {RESULTS_CSV_FILENAME} from {RESULTS_REPO}...")
            downloaded_path = hf_hub_download(
                repo_id=RESULTS_REPO,
                filename=RESULTS_CSV_FILENAME,
                repo_type="dataset",
                token=TOKEN,
            )
            path = Path(downloaded_path)
            print(f"Successfully downloaded results from {RESULTS_REPO}")
        except Exception as e:
            # Deliberate best-effort: any Hub failure falls back to the
            # local file instead of aborting the load.
            print(f"Could not download from HuggingFace Hub: {e}")
            print(f"Falling back to local file at {csv_path}")
            path = Path(csv_path)

    if not path.exists():
        raise FileNotFoundError(
            f"ASR results CSV not found at {path}. "
            "Please generate it via src.aggregate_results before launching the app."
        )

    frame = pd.read_csv(path)

    # Convert numeric columns; unparseable values become NaN.
    for column in ASR_NUMERIC_COLUMNS:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")

    # Fill missing text columns.
    for column in ASR_TEXT_COLUMNS:
        if column in frame.columns:
            frame[column] = frame[column].fillna("Unknown")

    # Normalize language names.
    if "language" in frame.columns:
        frame["language"] = frame["language"].replace(LANGUAGE_NAME_MAPPING)

    # Filter out rows with very few samples. The explicit copy makes the
    # result an independent frame, so the column assignment below does not
    # write into a slice of the unfiltered data.
    if "num_samples" in frame.columns:
        frame = frame[frame["num_samples"] >= MIN_SAMPLES_THRESHOLD].copy()

    # Add region metadata column.
    if "language" in frame.columns:
        frame["african_region"] = frame["language"].apply(
            lambda language: ", ".join(get_language_regions(language))
        )

    return frame


def load_asr_results(csv_path: Path = RESULTS_CSV_PATH) -> pd.DataFrame:
    """Load the ASR results DataFrame.

    The returned frame is the object cached by ``_cached_asr_results``;
    callers should copy it before mutating.
    """
    return _cached_asr_results(str(csv_path))


# =============================================================================
# Filtering
# =============================================================================


def _sorted_column_values(frame: pd.DataFrame, column: str) -> list[str]:
    """Get sorted unique values from a column, with 'Unknown' at the end.

    Returns an empty list when the frame is empty or lacks the column.
    """
    if column not in frame.columns or frame.empty:
        return []

    values = sorted(
        {value for value in frame[column].dropna().unique() if value != "Unknown"}
    )
    if (frame[column] == "Unknown").any():
        values.append("Unknown")
    return values


def _languages_for_regions(african_regions: list[str]) -> set[str]:
    """Return the union of languages belonging to any of the given regions."""
    # Imported lazily, matching the other function-scope imports in this module.
    from src.language_metadata import get_languages_by_region

    region_languages: set[str] = set()
    for region in african_regions:
        region_languages.update(get_languages_by_region(region))
    return region_languages


def get_filter_options(frame: pd.DataFrame) -> dict[str, list[str]]:
    """Get all available filter options from the DataFrame.

    Returns one entry per column in ``FILTER_COLUMN_ORDER`` plus
    ``dataset_group_labels`` (display labels for each dataset group) and
    ``african_region`` (all known regions, independent of the frame).
    """
    from src.about import get_dataset_group_label
    from src.language_metadata import get_all_regions

    options = {
        column: _sorted_column_values(frame, column) for column in FILTER_COLUMN_ORDER
    }
    options["dataset_group_labels"] = [
        get_dataset_group_label(dg) for dg in options.get("dataset_group", [])
    ]
    options["african_region"] = get_all_regions()
    return options


def get_languages_for_filters(
    frame: pd.DataFrame,
    african_regions: list[str] | None = None,
) -> list[str]:
    """Get languages that match the given region filter.

    With no regions selected, every language present in ``frame`` is
    returned; otherwise only languages that are both in a selected region and
    present in ``frame``.
    """
    available_languages = _sorted_column_values(frame, "language")
    if not african_regions:
        return available_languages

    region_languages = _languages_for_regions(african_regions)
    return sorted(region_languages & set(available_languages))


def filter_asr_dataframe(
    frame: pd.DataFrame,
    *,
    models: list[str] | None = None,
    languages: list[str] | None = None,
    dataset_groups: list[str] | None = None,
    african_regions: list[str] | None = None,
) -> pd.DataFrame:
    """Filter the ASR results DataFrame by the given criteria.

    Each criterion is skipped when it is ``None`` or empty. The input frame
    is never modified; a filtered copy is returned.
    """
    filtered = frame.copy()

    if models:
        filtered = filtered[filtered["model"].isin(models)]
    if languages:
        filtered = filtered[filtered["language"].isin(languages)]
    if dataset_groups:
        filtered = filtered[filtered["dataset_group"].isin(dataset_groups)]

    # Apply region filter by getting languages for selected regions. When the
    # selected regions resolve to no languages at all, the filter is skipped.
    if african_regions:
        region_languages = _languages_for_regions(african_regions)
        if region_languages and "language" in filtered.columns:
            filtered = filtered[filtered["language"].isin(region_languages)]

    return filtered


def prepare_display_dataframe(
    frame: pd.DataFrame, max_rows: int, include_split: bool = True
) -> pd.DataFrame:
    """Prepare a DataFrame for display with proper formatting.

    Keeps the ``ASR_DISPLAY_COLUMNS`` present in ``frame`` (optionally without
    ``split``), truncates to ``max_rows``, rounds metric and timing columns
    and prepends an unnamed 1-based row-number column.
    """
    columns = [
        col
        for col in ASR_DISPLAY_COLUMNS
        if col in frame.columns and (include_split or col != "split")
    ]

    # Truncate first so rounding only touches the rows that are shown.
    display = frame.loc[:, columns].head(max_rows).copy()

    for column in ("wer", "cer", "rtfx"):
        if column in display.columns:
            display[column] = display[column].round(3)

    for column in ("duration_sec", "inference_time_sec"):
        if column in display.columns:
            display[column] = display[column].round(2)

    display.insert(0, "", range(1, len(display) + 1))
    return display


# =============================================================================
# Aggregation
# =============================================================================


def _join_unique_values(values: pd.Series) -> str | None:
    """Join unique non-empty values from a Series with ``", "``.

    Returns ``None`` when ``values`` is ``None`` or has no non-blank entries.
    """
    if values is None:
        return None
    unique_values = [str(v) for v in values.dropna().unique() if str(v).strip()]
    return ", ".join(unique_values) if unique_values else None


def _weighted_average(series: pd.Series, weights: pd.Series) -> float | None:
    """Compute the weighted average of a series, rounded to 4 decimals.

    ``weights`` must share index labels with ``series``. Non-numeric values
    are ignored. If the total weight is not positive the plain mean is used.
    Returns ``None`` when no numeric values remain.
    """
    numeric = pd.to_numeric(series, errors="coerce").dropna()
    if numeric.empty:
        return None

    aligned_weights = weights.loc[numeric.index].fillna(0).astype(float)
    total_weight = aligned_weights.sum()
    if total_weight <= 0:
        return float(round(numeric.mean(), 4))
    return float(round((numeric * aligned_weights).sum() / total_weight, 4))


def _compute_rtfx_from_totals(
    duration_series: pd.Series, inference_time_series: pd.Series
) -> float | None:
    """Compute RTFx as Total Audio Duration / Total Transcription Time.

    Only rows where both values are numeric contribute. Returns ``None`` when
    there are no such rows or the total inference time is not positive.
    """
    duration_numeric = pd.to_numeric(duration_series, errors="coerce").dropna()
    inference_numeric = pd.to_numeric(inference_time_series, errors="coerce").dropna()

    common_index = duration_numeric.index.intersection(inference_numeric.index)
    if common_index.empty:
        return None

    total_duration = duration_numeric.loc[common_index].sum()
    total_inference = inference_numeric.loc[common_index].sum()
    if total_inference <= 0:
        return None
    return float(round(total_duration / total_inference, 4))


def _first_value(group: pd.DataFrame, column: str) -> object:
    """Return the first value of ``column`` in ``group``, or None if absent."""
    return group[column].iloc[0] if column in group.columns else None


def _numeric_total(group: pd.DataFrame, column: str) -> float | None:
    """Sum the numeric values of ``column`` (4 decimals), or None if none exist."""
    if column not in group.columns:
        return None
    numeric = pd.to_numeric(group[column], errors="coerce").dropna()
    if numeric.empty:
        return None
    return float(round(numeric.sum(), 4))


def aggregate_by_language(frame: pd.DataFrame) -> pd.DataFrame:
    """Aggregate results by model and language with proper metric computation.

    Rows are grouped by whichever of ``model_family``, ``model`` and
    ``language`` exist. Per group, WER/CER are averaged weighted by
    ``num_samples``, RTFx is recomputed as sum(duration) / sum(inference time),
    durations and inference times are summed, and ``dataset_group`` /
    ``region`` values are joined.

    Returns:
        One row per group. If ``frame`` is empty or has no ``language``
        column it is returned unchanged.
    """
    if frame.empty or "language" not in frame.columns:
        return frame

    # "language" is guaranteed present here, so the key list is never empty.
    group_keys = [
        col for col in ("model_family", "model", "language") if col in frame.columns
    ]
    has_timing = "duration_sec" in frame.columns and "inference_time_sec" in frame.columns

    aggregated_rows: list[dict[str, object]] = []
    for _, group in frame.groupby(group_keys, dropna=False):
        if "num_samples" in group.columns:
            weights = (
                pd.to_numeric(group["num_samples"], errors="coerce")
                .fillna(0)
                .astype(float)
            )
        else:
            # Unit weights must carry the group's own index labels, otherwise
            # the label-based alignment in _weighted_average raises KeyError.
            weights = pd.Series(1.0, index=group.index)

        # Compute RTFx: sum(duration) / sum(inference_time)
        rtfx_value = None
        if has_timing:
            rtfx_value = _compute_rtfx_from_totals(
                group["duration_sec"], group["inference_time_sec"]
            )

        total_weight = weights.sum()
        aggregated_rows.append({
            "model_family": _first_value(group, "model_family"),
            "model": _first_value(group, "model"),
            "dataset_group": (
                _join_unique_values(group["dataset_group"])
                if "dataset_group" in group.columns
                else None
            ),
            "language": _first_value(group, "language"),
            "region": (
                _join_unique_values(group["region"])
                if "region" in group.columns
                else None
            ),
            "wer": _weighted_average(group["wer"], weights) if "wer" in group.columns else None,
            "cer": _weighted_average(group["cer"], weights) if "cer" in group.columns else None,
            "rtfx": rtfx_value,
            "duration_sec": _numeric_total(group, "duration_sec"),
            "inference_time_sec": _numeric_total(group, "inference_time_sec"),
            # Fall back to the row count when no usable sample counts exist.
            "num_samples": int(total_weight) if total_weight > 0 else len(group),
        })

    return pd.DataFrame(aggregated_rows)
# ============================================================================= # Metric Tables # ============================================================================= def _pivot_metric_table(aggregated: pd.DataFrame, metric: str, view_mode: str) -> pd.DataFrame: """Create a pivot table for a specific metric.""" if aggregated.empty or metric not in aggregated.columns: return pd.DataFrame() column_key = VIEW_MODE_COLUMNS.get(view_mode, VIEW_MODE_COLUMNS[DEFAULT_VIEW_MODE]) if column_key not in aggregated.columns: fallback = "model" if "model" in aggregated.columns else None if fallback is None: return pd.DataFrame() column_key = fallback pivot = ( aggregated .pivot_table(index="language", columns=column_key, values=metric, aggfunc="median") .sort_index() ) pivot = pivot.dropna(how="all") if pivot.empty: return pivot # Sort columns by performance column_scores = pivot.mean(skipna=True) ascending = METRIC_CONFIGS.get(metric, {}).get("better", "lower") == "lower" ordered_columns = column_scores.sort_values(ascending=ascending).index.tolist() # Ensure all columns are included missing = [col for col in pivot.columns if col not in ordered_columns] ordered_columns.extend(missing) return pivot[ordered_columns] def _build_metric_table_html(pivot: pd.DataFrame, metric: str) -> str: """Build HTML table for a metric pivot table.""" config = METRIC_CONFIGS[metric] if pivot.empty: return f"

No {config['label']} data available for the current filters.

" header_cells = "".join(f"{col}" for col in pivot.columns) rows_html: list[str] = [] for i, (language, row) in enumerate(pivot.iterrows(), 1): cell_html = [] for column in pivot.columns: value = row[column] display = format_metric_value(value, config["fmt"]) cell_html.append(f"{display}") onclick = ( "(function(el){var tr=el.parentElement;" "var isHighlighted=tr.classList.contains('row-highlighted');" "tr.parentElement.querySelectorAll('tr').forEach(function(r){" "r.classList.remove('row-highlighted')});" "if(!isHighlighted){tr.classList.add('row-highlighted')}})(this)" ) rows_html.append( f"{i}{language}{''.join(cell_html)}" ) caption = f"Columns sorted by overall {config['label']} performance" table_html = ( "
" f"{caption}{header_cells}" f"{''.join(rows_html)}
#Language
" ) return table_html def _build_metric_dataframe(pivot: pd.DataFrame, metric: str) -> pd.DataFrame: """Build a sortable DataFrame from the pivot table.""" if pivot.empty: return pd.DataFrame() df = pivot.reset_index() df = df.rename(columns={"language": "Language"}) # Round numeric columns for col in df.columns: if col != "Language" and df[col].dtype in ['float64', 'float32']: df[col] = df[col].round(2) df.insert(0, "", range(1, len(df) + 1)) return df def compute_metric_tables( models: list[str] | None, languages: list[str] | None, dataset_groups: list[str] | None, view_mode: str, african_regions: list[str] | None = None, asr_results_df: pd.DataFrame | None = None, ) -> dict[str, str]: """Compute HTML metric tables for all metrics.""" if asr_results_df is None or asr_results_df.empty: return {metric: "

No ASR results loaded.

" for metric in METRIC_CONFIGS} filtered = filter_asr_dataframe( asr_results_df, models=models, languages=languages, dataset_groups=dataset_groups, african_regions=african_regions, ) aggregated = aggregate_by_language(filtered) tables: dict[str, str] = {} for metric in METRIC_CONFIGS: pivot = _pivot_metric_table(aggregated, metric, view_mode) tables[metric] = _build_metric_table_html(pivot, metric) return tables def compute_metric_dataframes( models: list[str] | None, languages: list[str] | None, dataset_groups: list[str] | None, view_mode: str, african_regions: list[str] | None = None, asr_results_df: pd.DataFrame | None = None, ) -> dict[str, pd.DataFrame]: """Compute metric DataFrames for interactive sorting.""" if asr_results_df is None or asr_results_df.empty: return {metric: pd.DataFrame() for metric in METRIC_CONFIGS} filtered = filter_asr_dataframe( asr_results_df, models=models, languages=languages, dataset_groups=dataset_groups, african_regions=african_regions, ) aggregated = aggregate_by_language(filtered) dataframes: dict[str, pd.DataFrame] = {} for metric in METRIC_CONFIGS: pivot = _pivot_metric_table(aggregated, metric, view_mode) dataframes[metric] = _build_metric_dataframe(pivot, metric) return dataframes # ============================================================================= # Helper Functions # ============================================================================= def strip_dataset_label(label: str) -> str: """Strip descriptor from dataset label. E.g., 'ALFFA (read speech & broadcast news, 3 languages)' -> 'ALFFA' """ if " (" in label: return label.split(" (")[0] return label def strip_dataset_labels(labels: list[str] | None) -> list[str]: """Strip descriptors from a list of dataset labels.""" if not labels: return [] return [strip_dataset_label(label) for label in labels]