ASR_AGENT_ / analysis /aggregate.py
unknown
Update wer and cer
d7df0a5
from __future__ import annotations
from typing import Dict
import pandas as pd
# "lang_type" values that aggregate_summary iterates over to fill "by_language".
LANGS = ["zh", "en", "mixed", "other"]
# Event "level" values that _lang_level_summary reports as "<level>_view".
LEVELS = ["char", "word"]
def _level_summary(df_events: pd.DataFrame, level: str) -> Dict:
out: Dict = {}
if df_events is None or len(df_events) == 0:
return out
q = df_events[df_events["level"] == level].copy() if "level" in df_events.columns else df_events.copy()
if len(q) == 0:
return out
out["sid_counts"] = {str(k): int(v) for k, v in q["op_type"].value_counts().to_dict().items()} if "op_type" in q.columns else {}
out["top_error_classes"] = {str(k): int(v) for k, v in q["error_class"].value_counts().head(20).to_dict().items()} if "error_class" in q.columns else {}
if "op_type" in q.columns:
sub = q[q["op_type"] == "S"].copy()
if len(sub) > 0:
sub["pair"] = sub["ref"].astype(str) + " -> " + sub["hyp"].astype(str)
out["top_confusions"] = {str(k): int(v) for k, v in sub["pair"].value_counts().head(30).to_dict().items()}
out["num_events"] = int(len(q))
return out
def _lang_level_summary(df_align: pd.DataFrame, df_events: pd.DataFrame, lang: str) -> Dict:
out: Dict = {}
qa = df_align[df_align["lang_type"] == lang].copy() if "lang_type" in df_align.columns else pd.DataFrame()
qe = df_events[df_events["lang_type"] == lang].copy() if (len(df_events) > 0 and "lang_type" in df_events.columns) else pd.DataFrame()
if len(qa) == 0:
return out
out["num_utts"] = int(len(qa))
if qa["cer"].notna().any():
out["cer_mean"] = float(qa["cer"].dropna().mean())
if qa["wer"].notna().any():
out["wer_mean"] = float(qa["wer"].dropna().mean())
if qa["primary_metric_value"].notna().any():
out["primary_metric_mean"] = float(qa["primary_metric_value"].dropna().mean())
for level in LEVELS:
level_out = _level_summary(qe, level)
if level_out:
out[f"{level}_view"] = level_out
return out
def aggregate_summary(df_events: pd.DataFrame, df_align: pd.DataFrame) -> Dict:
    """Build the nested summary dictionary from the event and utterance tables.

    Args:
        df_events: Event table (columns such as "level", "op_type",
            "lang_type", "is_primary_level" are used when present).
            ``None`` is treated as an empty table.
        df_align: Per-utterance table (columns such as "lang_type",
            "primary_level", "wer", "cer", "primary_metric_value",
            "device", "domain", "accent", "speaker" are used when present).
            ``None`` is treated as an empty table.

    Returns:
        A dict that always contains "overall", "primary_metrics",
        "char_view", "word_view", "by_language", "default_event_view",
        "primary_view", "sid_counts", "top_error_classes" and
        "top_confusions". "wer_mean", "cer_mean" and the
        "worst_<key>_by_<metric>" lists are added only when the required
        columns exist and contain data.
    """
    # Normalise missing inputs once so every section below, including the
    # per-language helper, can run without its own None check.
    if df_align is None:
        df_align = pd.DataFrame()
    if df_events is None:
        df_events = pd.DataFrame()

    def mean_of(frame: pd.DataFrame, column: str):
        """Return the mean of the non-null values of ``column``, or None."""
        if column not in frame.columns or not frame[column].notna().any():
            return None
        return float(frame[column].dropna().mean())

    summary: Dict = {"overall": {}, "primary_metrics": {}, "char_view": {}, "word_view": {}, "by_language": {}}
    overall = summary["overall"]
    primary_metrics = summary["primary_metrics"]

    has_rows = len(df_align) > 0
    has_lang = "lang_type" in df_align.columns

    overall["num_utts"] = int(len(df_align))
    if has_rows and has_lang:
        overall["lang_distribution"] = {
            str(k): int(v) for k, v in df_align["lang_type"].value_counts().to_dict().items()
        }
        if "primary_level" in df_align.columns:
            overall["primary_level_distribution"] = {
                str(k): int(v) for k, v in df_align["primary_level"].value_counts().to_dict().items()
            }

    if has_rows:
        for column, key in (("wer", "wer_mean"), ("cer", "cer_mean")):
            value = mean_of(df_align, column)
            if value is not None:
                summary[key] = value

        value = mean_of(df_align, "primary_metric_value")
        if value is not None:
            primary_metrics["overall_primary_error_mean"] = value

        if has_lang:
            per_language = (
                ("zh", "cer", "zh_cer_mean"),
                ("en", "wer", "en_wer_mean"),
                ("mixed", "primary_metric_value", "mixed_primary_error_mean"),
            )
            for lang, column, key in per_language:
                value = mean_of(df_align[df_align["lang_type"] == lang], column)
                if value is not None:
                    primary_metrics[key] = value

    summary["char_view"] = _level_summary(df_events, "char")
    summary["word_view"] = _level_summary(df_events, "word")

    # Word-level events become the default view only when "en" utterances
    # strictly outnumber "zh" ones; otherwise the default stays "char".
    lang_distribution = overall.get("lang_distribution", {})
    default_view = "word" if lang_distribution.get("en", 0) > lang_distribution.get("zh", 0) else "char"
    summary["default_event_view"] = default_view

    if len(df_events) > 0 and "is_primary_level" in df_events.columns:
        # Explicit comparison keeps rows whose flag equals True even if the
        # column is not a pure boolean dtype.
        primary_events = df_events[df_events["is_primary_level"] == True]  # noqa: E712
    else:
        primary_events = pd.DataFrame()
    primary_view = _level_summary(primary_events, level=default_view)
    summary["primary_view"] = primary_view
    summary["sid_counts"] = primary_view.get("sid_counts", {})
    summary["top_error_classes"] = primary_view.get("top_error_classes", {})
    summary["top_confusions"] = primary_view.get("top_confusions", {})

    for lang in LANGS:
        lang_out = _lang_level_summary(df_align, df_events, lang)
        if lang_out:
            summary["by_language"][lang] = lang_out

    # Ten groups with the highest mean metric for each grouping column.
    worst_specs = (
        ("device", "cer"),
        ("domain", "cer"),
        ("accent", "cer"),
        ("speaker", "primary_metric_value"),
    )
    for key, metric in worst_specs:
        if has_rows and key in df_align.columns and metric in df_align.columns:
            worst = df_align.groupby(key)[metric].mean().dropna().sort_values(ascending=False).head(10)
            if len(worst) > 0:
                summary[f"worst_{key}_by_{metric}"] = [{"key": str(k), metric: float(v)} for k, v in worst.items()]
    return summary