from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence

import pandas as pd

# Metadata columns of the alignment table that are scanned for per-slice CER gaps.
_DEFAULT_SLICE_KEYS = ("device", "domain", "accent", "speaker")


def _safe_ratio(n: int, d: int) -> float:
    """Return ``n / d`` as a float, or ``0.0`` when the denominator is zero."""
    return float(n / d) if d else 0.0


def _column_mean(df: pd.DataFrame, column: str) -> Optional[float]:
    """Return the NaN-skipping mean of ``df[column]``.

    Returns None when the column is absent or contains no non-null values.
    """
    if column not in df.columns or not df[column].notna().any():
        return None
    return float(df[column].dropna().mean())


def _ratio_cause(
    cause: str,
    label: str,
    count: int,
    total: int,
    *,
    threshold: float,
    base: float,
    cap: float,
    hint: str,
    recommendations: List[str],
) -> Optional[Dict[str, Any]]:
    """Build a root-cause entry when ``count / total`` reaches ``threshold``.

    Confidence is ``base + ratio``, capped at ``cap`` and rounded to 3 decimals.
    Returns None when the ratio is below ``threshold``.
    """
    ratio = _safe_ratio(count, total)
    if ratio < threshold:
        return None
    return {
        "cause": cause,
        "confidence": round(min(cap, base + ratio), 3),
        "evidence": [f"{label} = {count}/{total}", hint],
        "recommendations": recommendations,
    }


def _find_slice_findings(
    df_align: pd.DataFrame,
    slice_keys: Sequence[str],
    min_ratio: float = 1.5,
) -> List[Dict[str, Any]]:
    """Return slices whose worst-group mean CER is >= ``min_ratio`` x the best group's.

    For each key in ``slice_keys`` that exists in ``df_align`` and has at least one
    non-null value, mean CER is computed per group. A finding is recorded when there
    are at least two groups and the worst/best ratio reaches ``min_ratio``.
    """
    findings: List[Dict[str, Any]] = []
    if "cer" not in df_align.columns:
        return findings

    for key in slice_keys:
        if key not in df_align.columns or not df_align[key].notna().any():
            continue
        per_group = (
            df_align.groupby(key)["cer"].mean().dropna().sort_values(ascending=False)
        )
        if len(per_group) < 2:
            continue
        worst_val = float(per_group.iloc[0])
        best_val = float(per_group.iloc[-1])
        # NOTE: a best group with CER 0 makes the ratio undefined, so such slices are
        # skipped even though the disparity may be large.
        if best_val <= 0:
            continue
        ratio = worst_val / best_val
        if ratio >= min_ratio:
            findings.append({
                "slice_key": key,
                "worst_group": str(per_group.index[0]),
                "worst_cer": worst_val,
                "best_cer": best_val,
                "ratio": ratio,
            })
    return findings


def _slice_cause(slice_findings: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Build the ``slice_specific_weakness`` entry, or None when there are no findings."""
    if not slice_findings:
        return None
    return {
        "cause": "slice_specific_weakness",
        "confidence": 0.85,
        "evidence": [
            "Some slices show much worse CER than others.",
            *[
                f"{x['slice_key']}={x['worst_group']} has CER {x['worst_cer']:.4f}, "
                f"ratio vs best={x['ratio']:.2f}"
                for x in slice_findings[:5]
            ],
        ],
        "recommendations": [
            "Prioritize the worst slices in future analysis/training.",
            "Check whether those slices correspond to accent, device, or scenario mismatch.",
        ],
    }


def infer_root_causes(
    df_events: pd.DataFrame,
    df_align: pd.DataFrame,
    *,
    slice_keys: Sequence[str] = _DEFAULT_SLICE_KEYS,
) -> Dict[str, Any]:
    """Infer likely ASR root causes from error events and alignment metrics.

    Rule/statistics based: each heuristic fires when its share of error events (or
    a per-slice CER gap) crosses a fixed threshold.

    Args:
        df_events: events.parquet loaded as a DataFrame. Optional columns read:
            ``op_type`` (counts of "D"/"I"/"S" are used) and ``error_class``
            (counts of "number_or_time"/"mixed_language" are used).
        df_align: aligned.jsonl loaded as a DataFrame. Optional columns read:
            ``wer``, ``cer`` and the columns named in ``slice_keys``.
        slice_keys: column names of ``df_align`` to group by when looking for
            slices with disproportionately high CER.

    Returns:
        A dict with keys ``overview`` (utterance/event counts and mean WER/CER,
        None when unavailable), ``root_causes`` (list of dicts with ``cause``,
        ``confidence``, ``evidence``, ``recommendations``) and ``evidence_tables``
        (``op_counts``, ``error_class_counts`` and, when present, ``slice_findings``).
    """
    result: Dict[str, Any] = {
        "overview": {},
        "root_causes": [],
        "evidence_tables": {},
    }

    total_events = len(df_events)
    total_utts = len(df_align)

    result["overview"] = {
        "num_utterances": int(total_utts),
        "num_error_events": int(total_events),
        "wer_mean": _column_mean(df_align, "wer"),
        "cer_mean": _column_mean(df_align, "cer"),
    }

    if total_events == 0:
        result["root_causes"].append({
            "cause": "no_errors_detected",
            "confidence": 1.0,
            "evidence": ["No error events found in current run."],
            "recommendations": [
                "Use a weaker model or more difficult dataset to make diagnosis meaningful."
            ],
        })
        return result

    # Basic counts
    op_counts = (
        df_events["op_type"].value_counts().to_dict()
        if "op_type" in df_events.columns
        else {}
    )
    cls_counts = (
        df_events["error_class"].value_counts().to_dict()
        if "error_class" in df_events.columns
        else {}
    )
    result["evidence_tables"]["op_counts"] = {k: int(v) for k, v in op_counts.items()}
    result["evidence_tables"]["error_class_counts"] = {
        k: int(v) for k, v in cls_counts.items()
    }

    num_time_count = int(cls_counts.get("number_or_time", 0))
    mixed_count = int(cls_counts.get("mixed_language", 0))
    deletion_count = int(op_counts.get("D", 0))
    insertion_count = int(op_counts.get("I", 0))
    substitution_count = int(op_counts.get("S", 0))

    slice_findings = _find_slice_findings(df_align, slice_keys)
    if slice_findings:
        result["evidence_tables"]["slice_findings"] = slice_findings

    # Candidates are listed in the order they are reported.
    candidates = [
        # Cause 1: number/time normalization problems
        _ratio_cause(
            "number_time_format", "number_or_time events", num_time_count, total_events,
            threshold=0.15, base=0.5, cap=0.95,
            hint="Large proportion of errors are related to numbers, dates, times, or units.",
            recommendations=[
                "Add number/date/time normalization in both reference and hypothesis.",
                "Create post-processing rules for time/unit expressions.",
                "Add more number-heavy utterances into evaluation/training.",
            ],
        ),
        # Cause 2: mixed-language problems
        _ratio_cause(
            "mixed_language", "mixed_language events", mixed_count, total_events,
            threshold=0.10, base=0.45, cap=0.95,
            hint=(
                "Frequent English/Latin-token related substitutions suggest "
                "code-switching weakness."
            ),
            recommendations=[
                "Add bilingual/code-switching evaluation samples.",
                "Add domain-specific English terms, abbreviations, and brand names.",
                "Add post-processing lexicon for mixed-language phrases.",
            ],
        ),
        # Cause 3: deletion-heavy => possible noise / far-field / VAD
        _ratio_cause(
            "noise_or_farfield_or_vad", "deletion events", deletion_count, total_events,
            threshold=0.30, base=0.5, cap=0.95,
            hint=(
                "High deletion ratio often indicates weak audibility, noise, "
                "far-field speech, or segmentation/VAD issues."
            ),
            recommendations=[
                "Compare CER/WER across device / SNR / domain slices.",
                "Inspect quiet, noisy, or long utterances.",
                "Tune VAD or segmentation strategy.",
                "Add noisy / far-field augmented audio.",
            ],
        ),
        # Cause 4: insertion-heavy => possible segmentation/repetition/echo
        _ratio_cause(
            "segmentation_or_repetition", "insertion events", insertion_count, total_events,
            threshold=0.20, base=0.45, cap=0.9,
            hint=(
                "High insertion ratio often suggests repeated decoding, "
                "segmentation mismatch, or echo."
            ),
            recommendations=[
                "Inspect duplicated filler words and repeated fragments.",
                "Review chunking / segmentation.",
                "Check whether punctuation or normalization creates false insertions.",
            ],
        ),
        # Cause 5: slice-based evidence (device/domain/accent/speaker by default)
        _slice_cause(slice_findings),
        # Cause 6: substitution-dominant => pronunciation / lexical confusion
        _ratio_cause(
            "pronunciation_or_lexical_confusion", "substitution events",
            substitution_count, total_events,
            threshold=0.60, base=0.45, cap=0.9,
            hint=(
                "Substitutions dominate, which often indicates pronunciation ambiguity, "
                "lexical confusion, or near-homophone errors."
            ),
            recommendations=[
                "Add confusion-pair statistics.",
                "Check near-homophone and accent-sensitive confusions.",
                "Build a pronunciation-aware analysis layer.",
            ],
        ),
    ]
    result["root_causes"].extend(c for c in candidates if c is not None)

    if not result["root_causes"]:
        result["root_causes"].append({
            "cause": "general_asr_mismatch",
            "confidence": 0.5,
            "evidence": ["No single dominant root cause identified from current heuristics."],
            "recommendations": [
                "Inspect top confusion pairs and low-performing slices.",
                "Increase metadata coverage (device/domain/accent/snr).",
            ],
        })

    return result