File size: 8,309 Bytes
2d79471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from __future__ import annotations

from typing import Dict, List, Any
import pandas as pd


def _safe_ratio(n: int, d: int) -> float:
    return float(n / d) if d else 0.0


def _mean_or_none(df: pd.DataFrame, col: str) -> float | None:
    """Mean of metric column *col* ignoring NaNs; None if absent or all-NaN."""
    if col in df.columns and df[col].notna().any():
        return float(df[col].dropna().mean())
    return None


def _ratio_cause(
    count: int,
    total: int,
    *,
    threshold: float,
    base_conf: float,
    conf_cap: float,
    cause: str,
    label: str,
    explanation: str,
    recommendations: List[str],
) -> Dict[str, Any] | None:
    """Build one threshold-triggered root-cause entry, or None below *threshold*.

    confidence = round(min(conf_cap, base_conf + count/total), 3): more
    prevalent error patterns score higher, capped so heuristics never claim
    near-certainty.
    """
    ratio = (count / total) if total else 0.0
    if ratio < threshold:
        return None
    return {
        "cause": cause,
        "confidence": round(min(conf_cap, base_conf + ratio), 3),
        "evidence": [
            f"{label} events = {count}/{total}",
            explanation,
        ],
        "recommendations": recommendations,
    }


def _weak_slices(df_align: pd.DataFrame) -> List[Dict[str, Any]]:
    """Find metadata slices whose mean CER is >= 1.5x the best slice's.

    Checks the device/domain/accent/speaker columns when present; a key
    contributes at most one finding (its worst group vs its best group).
    """
    findings: List[Dict[str, Any]] = []
    if "cer" not in df_align.columns:
        return findings
    for key in ["device", "domain", "accent", "speaker"]:
        if key not in df_align.columns or not df_align[key].notna().any():
            continue
        g = df_align.groupby(key)["cer"].mean().dropna().sort_values(ascending=False)
        if len(g) < 2:
            continue
        worst_val = float(g.iloc[0])
        best_val = float(g.iloc[-1])
        # best_val > 0 guard avoids division by zero on perfect slices.
        if best_val > 0 and worst_val / best_val >= 1.5:
            findings.append({
                "slice_key": key,
                "worst_group": str(g.index[0]),
                "worst_cer": worst_val,
                "best_cer": best_val,
                "ratio": worst_val / best_val,
            })
    return findings


def infer_root_causes(df_events: pd.DataFrame, df_align: pd.DataFrame) -> Dict[str, Any]:
    """
    Rule/statistics based root-cause inference.
    Input:
      - df_events: events.parquet loaded as DataFrame (one row per error event)
      - df_align: aligned.jsonl loaded as DataFrame (one row per utterance)
    Output:
      - dict with evidence, issue hypotheses, and recommendations
    """
    result: Dict[str, Any] = {
        "overview": {},
        "root_causes": [],
        "evidence_tables": {},
    }

    total_events = len(df_events)
    total_utts = len(df_align)

    result["overview"] = {
        "num_utterances": int(total_utts),
        "num_error_events": int(total_events),
        "wer_mean": _mean_or_none(df_align, "wer"),
        "cer_mean": _mean_or_none(df_align, "cer"),
    }

    # No errors at all: nothing to diagnose, report that explicitly and stop.
    if total_events == 0:
        result["root_causes"].append({
            "cause": "no_errors_detected",
            "confidence": 1.0,
            "evidence": ["No error events found in current run."],
            "recommendations": ["Use a weaker model or more difficult dataset to make diagnosis meaningful."]
        })
        return result

    # Basic counts (tolerate missing columns in partial event dumps).
    op_counts = df_events["op_type"].value_counts().to_dict() if "op_type" in df_events.columns else {}
    cls_counts = df_events["error_class"].value_counts().to_dict() if "error_class" in df_events.columns else {}

    result["evidence_tables"]["op_counts"] = {k: int(v) for k, v in op_counts.items()}
    result["evidence_tables"]["error_class_counts"] = {k: int(v) for k, v in cls_counts.items()}

    deletion_count = int(op_counts.get("D", 0))
    insertion_count = int(op_counts.get("I", 0))
    substitution_count = int(op_counts.get("S", 0))

    # --- Causes 1-4: threshold-triggered hypotheses from event proportions.
    for cand in (
        # Cause 1: number/time normalization problems
        _ratio_cause(
            int(cls_counts.get("number_or_time", 0)), total_events,
            threshold=0.15, base_conf=0.5, conf_cap=0.95,
            cause="number_time_format", label="number_or_time",
            explanation="Large proportion of errors are related to numbers, dates, times, or units.",
            recommendations=[
                "Add number/date/time normalization in both reference and hypothesis.",
                "Create post-processing rules for time/unit expressions.",
                "Add more number-heavy utterances into evaluation/training."
            ],
        ),
        # Cause 2: mixed-language problems
        _ratio_cause(
            int(cls_counts.get("mixed_language", 0)), total_events,
            threshold=0.10, base_conf=0.45, conf_cap=0.95,
            cause="mixed_language", label="mixed_language",
            explanation="Frequent English/Latin-token related substitutions suggest code-switching weakness.",
            recommendations=[
                "Add bilingual/code-switching evaluation samples.",
                "Add domain-specific English terms, abbreviations, and brand names.",
                "Add post-processing lexicon for mixed-language phrases."
            ],
        ),
        # Cause 3: deletion-heavy => possible noise / far-field / VAD
        _ratio_cause(
            deletion_count, total_events,
            threshold=0.30, base_conf=0.5, conf_cap=0.95,
            cause="noise_or_farfield_or_vad", label="deletion",
            explanation="High deletion ratio often indicates weak audibility, noise, far-field speech, or segmentation/VAD issues.",
            recommendations=[
                "Compare CER/WER across device / SNR / domain slices.",
                "Inspect quiet, noisy, or long utterances.",
                "Tune VAD or segmentation strategy.",
                "Add noisy / far-field augmented audio."
            ],
        ),
        # Cause 4: insertion-heavy => possible segmentation/repetition/echo
        _ratio_cause(
            insertion_count, total_events,
            threshold=0.20, base_conf=0.45, conf_cap=0.9,
            cause="segmentation_or_repetition", label="insertion",
            explanation="High insertion ratio often suggests repeated decoding, segmentation mismatch, or echo.",
            recommendations=[
                "Inspect duplicated filler words and repeated fragments.",
                "Review chunking / segmentation.",
                "Check whether punctuation or normalization creates false insertions."
            ],
        ),
    ):
        if cand is not None:
            result["root_causes"].append(cand)

    # --- Cause 5: slice-based evidence (device/domain/accent/speaker)
    slice_findings = _weak_slices(df_align)
    if slice_findings:
        result["evidence_tables"]["slice_findings"] = slice_findings
        result["root_causes"].append({
            "cause": "slice_specific_weakness",
            "confidence": 0.85,
            "evidence": [
                "Some slices show much worse CER than others.",
                *[
                    f"{x['slice_key']}={x['worst_group']} has CER {x['worst_cer']:.4f}, ratio vs best={x['ratio']:.2f}"
                    for x in slice_findings[:5]
                ]
            ],
            "recommendations": [
                "Prioritize the worst slices in future analysis/training.",
                "Check whether those slices correspond to accent, device, or scenario mismatch."
            ]
        })

    # --- Cause 6: substitution-dominant => pronunciation / lexical confusion
    cand = _ratio_cause(
        substitution_count, total_events,
        threshold=0.60, base_conf=0.45, conf_cap=0.9,
        cause="pronunciation_or_lexical_confusion", label="substitution",
        explanation="Substitutions dominate, which often indicates pronunciation ambiguity, lexical confusion, or near-homophone errors.",
        recommendations=[
            "Add confusion-pair statistics.",
            "Check near-homophone and accent-sensitive confusions.",
            "Build a pronunciation-aware analysis layer."
        ],
    )
    if cand is not None:
        result["root_causes"].append(cand)

    # Fallback so the report stays actionable even when no rule fired.
    if not result["root_causes"]:
        result["root_causes"].append({
            "cause": "general_asr_mismatch",
            "confidence": 0.5,
            "evidence": ["No single dominant root cause identified from current heuristics."],
            "recommendations": [
                "Inspect top confusion pairs and low-performing slices.",
                "Increase metadata coverage (device/domain/accent/snr)."
            ]
        })

    return result