Spaces:
Running
Running
| """ | |
| Data loading, filtering, aggregation and metric computation for PazaBench. | |
| This module handles all data processing logic: | |
| - Loading ASR results from CSV (with HuggingFace Hub fallback) | |
| - Filtering dataframes by model, language, dataset, region | |
| - Aggregating results by language with proper RTFx computation | |
| - Building metric pivot tables and DataFrames | |
| """ | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
| from src.constants import ( | |
| ASR_DISPLAY_COLUMNS, | |
| ASR_NUMERIC_COLUMNS, | |
| ASR_TEXT_COLUMNS, | |
| DEFAULT_VIEW_MODE, | |
| FILTER_COLUMN_ORDER, | |
| LANGUAGE_NAME_MAPPING, | |
| METRIC_CONFIGS, | |
| RESULTS_CSV_FILENAME, | |
| RESULTS_CSV_PATH, | |
| VIEW_MODE_COLUMNS, | |
| ) | |
| from src.display.styling import format_metric_value | |
| from src.envs import HF_ENABLED, RESULTS_REPO, TOKEN | |
| from src.language_metadata import get_language_regions | |
| # ============================================================================= | |
| # Data Loading | |
| # ============================================================================= | |
@lru_cache(maxsize=None)
def _cached_asr_results(csv_path: str) -> pd.DataFrame:
    """Load ASR results from HuggingFace Hub or fall back to local file.

    The result is memoised per ``csv_path`` (the module already imports
    ``lru_cache``, which was previously unused), so the CSV is downloaded and
    parsed at most once per process. Callers must treat the returned frame as
    read-only, since the same object is shared between calls.

    Raises:
        FileNotFoundError: if the Hub download fails (or is disabled) and no
            local CSV exists at ``csv_path``.
    """
    path = Path(csv_path)
    if HF_ENABLED and RESULTS_REPO:
        try:
            print(f"Downloading {RESULTS_CSV_FILENAME} from {RESULTS_REPO}...")
            downloaded_path = hf_hub_download(
                repo_id=RESULTS_REPO,
                filename=RESULTS_CSV_FILENAME,
                repo_type="dataset",
                token=TOKEN,
            )
            path = Path(downloaded_path)
            print(f"Successfully downloaded results from {RESULTS_REPO}")
        except Exception as e:
            # Best-effort download: any failure falls back to the local copy.
            print(f"Could not download from HuggingFace Hub: {e}")
            print(f"Falling back to local file at {csv_path}")
            path = Path(csv_path)
    if not path.exists():
        raise FileNotFoundError(
            f"ASR results CSV not found at {path}. "
            "Please generate it via src.aggregate_results before launching the app."
        )
    frame = pd.read_csv(path)
    # Coerce metric columns to numeric; unparsable values become NaN.
    for column in ASR_NUMERIC_COLUMNS:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")
    # Fill missing text columns so filters get a stable "Unknown" bucket.
    for column in ASR_TEXT_COLUMNS:
        if column in frame.columns:
            frame[column] = frame[column].fillna("Unknown")
    # Normalize language names to their canonical display form.
    if "language" in frame.columns:
        frame["language"] = frame["language"].replace(LANGUAGE_NAME_MAPPING)
    # Drop rows evaluated on very few samples (unreliable metrics).
    MIN_SAMPLES_THRESHOLD = 10
    if "num_samples" in frame.columns:
        frame = frame[frame["num_samples"] >= MIN_SAMPLES_THRESHOLD]
    # Annotate each row with its language's African region(s), comma-joined.
    if "language" in frame.columns:
        frame["african_region"] = frame["language"].apply(
            lambda x: ", ".join(get_language_regions(x))
        )
    return frame
def load_asr_results(csv_path: Path = RESULTS_CSV_PATH) -> pd.DataFrame:
    """Return the ASR results DataFrame for *csv_path* (cached per path)."""
    cache_key = str(csv_path)
    return _cached_asr_results(cache_key)
| # ============================================================================= | |
| # Filtering | |
| # ============================================================================= | |
| def _sorted_column_values(frame: pd.DataFrame, column: str) -> list[str]: | |
| """Get sorted unique values from a column, with 'Unknown' at the end.""" | |
| if column not in frame.columns or frame.empty: | |
| return [] | |
| values = sorted({value for value in frame[column].dropna().unique() if value != "Unknown"}) | |
| if (frame[column] == "Unknown").any(): | |
| values.append("Unknown") | |
| return values | |
def get_filter_options(frame: pd.DataFrame) -> dict[str, list[str]]:
    """Collect the selectable values for every filter widget."""
    from src.about import get_dataset_group_label
    from src.language_metadata import get_all_regions

    options: dict[str, list[str]] = {}
    for column in FILTER_COLUMN_ORDER:
        options[column] = _sorted_column_values(frame, column)
    # Human-readable labels mirror the raw dataset_group values, in order.
    options["dataset_group_labels"] = [
        get_dataset_group_label(group) for group in options.get("dataset_group", [])
    ]
    options["african_region"] = get_all_regions()
    return options
def get_languages_for_filters(
    frame: pd.DataFrame,
    african_regions: list[str] | None = None,
) -> list[str]:
    """Return the frame's languages, restricted to the selected regions."""
    from src.language_metadata import get_languages_by_region

    present = _sorted_column_values(frame, "language")
    if not african_regions:
        return present
    selected: set[str] = set()
    for region in african_regions:
        selected |= set(get_languages_by_region(region))
    return sorted(selected & set(present))
| def filter_asr_dataframe( | |
| frame: pd.DataFrame, | |
| *, | |
| models: list[str] | None = None, | |
| languages: list[str] | None = None, | |
| dataset_groups: list[str] | None = None, | |
| african_regions: list[str] | None = None, | |
| ) -> pd.DataFrame: | |
| """Filter the ASR results DataFrame by the given criteria.""" | |
| from src.language_metadata import get_languages_by_region | |
| filtered = frame.copy() | |
| if models: | |
| filtered = filtered[filtered["model"].isin(models)] | |
| if languages: | |
| filtered = filtered[filtered["language"].isin(languages)] | |
| if dataset_groups: | |
| filtered = filtered[filtered["dataset_group"].isin(dataset_groups)] | |
| # Apply region filter by getting languages for selected regions | |
| if african_regions: | |
| region_languages: set[str] = set() | |
| for region in african_regions: | |
| region_languages.update(get_languages_by_region(region)) | |
| if region_languages and "language" in filtered.columns: | |
| filtered = filtered[filtered["language"].isin(region_languages)] | |
| return filtered | |
def prepare_display_dataframe(
    frame: pd.DataFrame,
    max_rows: int,
    include_split: bool = True
) -> pd.DataFrame:
    """Select, round, truncate and rank-index the columns used for display."""
    columns = [col for col in ASR_DISPLAY_COLUMNS if col in frame.columns]
    if not include_split:
        columns = [col for col in columns if col != "split"]
    display = frame.loc[:, columns].copy()
    # Metric columns get 3 decimals, timing totals get 2.
    precision = {"wer": 3, "cer": 3, "rtfx": 3, "duration_sec": 2, "inference_time_sec": 2}
    for column, digits in precision.items():
        if column in display.columns:
            display[column] = display[column].round(digits)
    display = display.head(max_rows)
    # Leading unnamed column holds the 1-based display rank.
    display.insert(0, "", range(1, len(display) + 1))
    return display
| # ============================================================================= | |
| # Aggregation | |
| # ============================================================================= | |
| def _join_unique_values(values: pd.Series) -> str | None: | |
| """Join unique non-empty values from a Series.""" | |
| if values is None: | |
| return None | |
| unique_values = [str(v) for v in values.dropna().unique() if str(v).strip()] | |
| return ", ".join(unique_values) if unique_values else None | |
| def _weighted_average(series: pd.Series, weights: pd.Series) -> float | None: | |
| """Compute weighted average of a series.""" | |
| numeric = pd.to_numeric(series, errors="coerce").dropna() | |
| if numeric.empty: | |
| return None | |
| aligned_weights = weights.loc[numeric.index].fillna(0).astype(float) | |
| total_weight = aligned_weights.sum() | |
| if total_weight <= 0: | |
| return float(round(numeric.mean(), 4)) | |
| return float(round((numeric * aligned_weights).sum() / total_weight, 4)) | |
| def _compute_rtfx_from_totals( | |
| duration_series: pd.Series, | |
| inference_time_series: pd.Series | |
| ) -> float | None: | |
| """Compute RTFx as Total Audio Duration / Total Transcription Time.""" | |
| duration_numeric = pd.to_numeric(duration_series, errors="coerce").dropna() | |
| inference_numeric = pd.to_numeric(inference_time_series, errors="coerce").dropna() | |
| common_index = duration_numeric.index.intersection(inference_numeric.index) | |
| if common_index.empty: | |
| return None | |
| total_duration = duration_numeric.loc[common_index].sum() | |
| total_inference = inference_numeric.loc[common_index].sum() | |
| if total_inference <= 0: | |
| return None | |
| return float(round(total_duration / total_inference, 4)) | |
def aggregate_by_language(frame: pd.DataFrame) -> pd.DataFrame:
    """Aggregate results by model and language with proper metric computation.

    WER/CER become sample-weighted averages, RTFx is recomputed from summed
    durations (never averaged), and durations/inference times are totalled.
    Returns the input unchanged when it is empty or has no grouping columns.

    Bug fix: the fallback weight series (used when 'num_samples' is missing)
    now carries ``group.index`` — the previous fresh RangeIndex did not align
    with the group's rows, so ``_weighted_average``'s ``weights.loc[...]``
    would raise KeyError.
    """
    if frame.empty or "language" not in frame.columns:
        return frame
    group_keys = [col for col in ["model_family", "model", "language"] if col in frame.columns]
    if not group_keys:
        return frame

    def _first(group: pd.DataFrame, column: str):
        # First value of a column, or None when the column is absent.
        return group[column].iloc[0] if column in group.columns else None

    def _joined(group: pd.DataFrame, column: str):
        # Comma-joined unique values, or None when the column is absent.
        return _join_unique_values(group[column]) if column in group.columns else None

    def _total(group: pd.DataFrame, column: str) -> float | None:
        # Numeric column sum rounded to 4 decimals; None if absent or empty.
        if column not in group.columns:
            return None
        numeric = pd.to_numeric(group[column], errors="coerce").dropna()
        return float(round(numeric.sum(), 4)) if not numeric.empty else None

    aggregated_rows: list[dict[str, object]] = []
    for _, group in frame.groupby(group_keys, dropna=False):
        if "num_samples" in group.columns:
            weights = group["num_samples"]
        else:
            # Uniform weights, aligned to the group's index (see docstring).
            weights = pd.Series([1] * len(group), index=group.index)
        weights = pd.to_numeric(weights, errors="coerce").fillna(0).astype(float)
        # RTFx = sum(duration) / sum(inference_time), never a mean of ratios.
        rtfx_value = None
        if "duration_sec" in group.columns and "inference_time_sec" in group.columns:
            rtfx_value = _compute_rtfx_from_totals(group["duration_sec"], group["inference_time_sec"])
        aggregated_rows.append({
            "model_family": _first(group, "model_family"),
            "model": _first(group, "model"),
            "dataset_group": _joined(group, "dataset_group"),
            "language": _first(group, "language"),
            "region": _joined(group, "region"),
            "wer": _weighted_average(group["wer"], weights) if "wer" in group.columns else None,
            "cer": _weighted_average(group["cer"], weights) if "cer" in group.columns else None,
            "rtfx": rtfx_value,
            "duration_sec": _total(group, "duration_sec"),
            "inference_time_sec": _total(group, "inference_time_sec"),
            "num_samples": int(weights.sum()) if weights.sum() > 0 else len(group),
        })
    return pd.DataFrame(aggregated_rows)
| # ============================================================================= | |
| # Metric Tables | |
| # ============================================================================= | |
def _pivot_metric_table(aggregated: pd.DataFrame, metric: str, view_mode: str) -> pd.DataFrame:
    """Pivot *aggregated* into a language-by-column table for one metric.

    Columns come from the view mode's mapped key (falling back to 'model'),
    and are ordered best-first by their mean score, honouring the metric's
    'better' direction from METRIC_CONFIGS.
    """
    if aggregated.empty or metric not in aggregated.columns:
        return pd.DataFrame()
    column_key = VIEW_MODE_COLUMNS.get(view_mode, VIEW_MODE_COLUMNS[DEFAULT_VIEW_MODE])
    if column_key not in aggregated.columns:
        if "model" not in aggregated.columns:
            return pd.DataFrame()
        column_key = "model"
    table = aggregated.pivot_table(
        index="language", columns=column_key, values=metric, aggfunc="median"
    ).sort_index()
    table = table.dropna(how="all")
    if table.empty:
        return table
    # Rank columns by mean score; lower-is-better metrics sort ascending.
    lower_is_better = METRIC_CONFIGS.get(metric, {}).get("better", "lower") == "lower"
    ranked = table.mean(skipna=True).sort_values(ascending=lower_is_better).index.tolist()
    # Defensive: keep any column the ranking somehow missed.
    ranked += [col for col in table.columns if col not in ranked]
    return table[ranked]
def _build_metric_table_html(pivot: pd.DataFrame, metric: str) -> str:
    """Render the metric pivot table as an HTML <table> with clickable rows."""
    config = METRIC_CONFIGS[metric]
    if pivot.empty:
        return f"<p class='metric-table-empty'>No {config['label']} data available for the current filters.</p>"
    # Clicking a language header toggles a highlight on its row, clearing
    # the highlight from every other row first. Identical for all rows,
    # so it is built once outside the loop.
    onclick = (
        "(function(el){var tr=el.parentElement;"
        "var isHighlighted=tr.classList.contains('row-highlighted');"
        "tr.parentElement.querySelectorAll('tr').forEach(function(r){"
        "r.classList.remove('row-highlighted')});"
        "if(!isHighlighted){tr.classList.add('row-highlighted')}})(this)"
    )
    header_cells = "".join(f"<th>{col}</th>" for col in pivot.columns)
    body_rows: list[str] = []
    for rank, (language, row) in enumerate(pivot.iterrows(), 1):
        cells = "".join(
            f"<td>{format_metric_value(row[column], config['fmt'])}</td>"
            for column in pivot.columns
        )
        body_rows.append(
            f"<tr><td>{rank}</td><th onclick=\"{onclick}\">{language}</th>{cells}</tr>"
        )
    caption = f"<caption>Columns sorted by overall {config['label']} performance</caption>"
    return (
        "<div class='metric-table-wrapper'>"
        f"<table class='metric-table'>{caption}<thead><tr><th>#</th><th>Language</th>{header_cells}</tr></thead>"
        f"<tbody>{''.join(body_rows)}</tbody></table></div>"
    )
| def _build_metric_dataframe(pivot: pd.DataFrame, metric: str) -> pd.DataFrame: | |
| """Build a sortable DataFrame from the pivot table.""" | |
| if pivot.empty: | |
| return pd.DataFrame() | |
| df = pivot.reset_index() | |
| df = df.rename(columns={"language": "Language"}) | |
| # Round numeric columns | |
| for col in df.columns: | |
| if col != "Language" and df[col].dtype in ['float64', 'float32']: | |
| df[col] = df[col].round(2) | |
| df.insert(0, "", range(1, len(df) + 1)) | |
| return df | |
def compute_metric_tables(
    models: list[str] | None,
    languages: list[str] | None,
    dataset_groups: list[str] | None,
    view_mode: str,
    african_regions: list[str] | None = None,
    asr_results_df: pd.DataFrame | None = None,
) -> dict[str, str]:
    """Build one HTML metric table per configured metric for the selection."""
    if asr_results_df is None or asr_results_df.empty:
        placeholder = "<p class='metric-table-empty'>No ASR results loaded.</p>"
        return {metric: placeholder for metric in METRIC_CONFIGS}
    subset = filter_asr_dataframe(
        asr_results_df,
        models=models,
        languages=languages,
        dataset_groups=dataset_groups,
        african_regions=african_regions,
    )
    per_language = aggregate_by_language(subset)
    return {
        metric: _build_metric_table_html(
            _pivot_metric_table(per_language, metric, view_mode), metric
        )
        for metric in METRIC_CONFIGS
    }
def compute_metric_dataframes(
    models: list[str] | None,
    languages: list[str] | None,
    dataset_groups: list[str] | None,
    view_mode: str,
    african_regions: list[str] | None = None,
    asr_results_df: pd.DataFrame | None = None,
) -> dict[str, pd.DataFrame]:
    """Build one sortable DataFrame per configured metric for the selection."""
    if asr_results_df is None or asr_results_df.empty:
        return {metric: pd.DataFrame() for metric in METRIC_CONFIGS}
    subset = filter_asr_dataframe(
        asr_results_df,
        models=models,
        languages=languages,
        dataset_groups=dataset_groups,
        african_regions=african_regions,
    )
    per_language = aggregate_by_language(subset)
    return {
        metric: _build_metric_dataframe(
            _pivot_metric_table(per_language, metric, view_mode), metric
        )
        for metric in METRIC_CONFIGS
    }
| # ============================================================================= | |
| # Helper Functions | |
| # ============================================================================= | |
def strip_dataset_label(label: str) -> str:
    """Drop the parenthesised descriptor from a dataset label.

    E.g., 'ALFFA (read speech & broadcast news, 3 languages)' -> 'ALFFA'
    """
    # partition returns the whole label as head when ' (' is absent.
    head, _sep, _tail = label.partition(" (")
    return head
def strip_dataset_labels(labels: list[str] | None) -> list[str]:
    """Strip the descriptor from every label; None or empty input yields []."""
    if labels:
        return [strip_dataset_label(label) for label in labels]
    return []