File size: 3,885 Bytes
3c8a5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
NLU eval harness — intent accuracy, threshold sweep, per-intent recall.
=======================================================================
Usage:
    python eval_nlu.py                 # leave-one-out over INTENT_EXAMPLES
    python eval_nlu.py labeled.csv     # external set: UTF-8 CSV, header row text,intent
                                       # (label out-of-scope rows 'unknown')

Leave-one-out is a *sanity* tool: for each example phrase, it is held out of
its own intent's centroid, then classified. It tells you whether the centroids
are internally separable and lets you sweep CONFIDENCE_THRESHOLD. It does NOT
substitute for a held-out, human-labeled Hausa test set — collect that from the
turn logs (logging_util.py) and pass it as the CSV argument.
"""
from __future__ import annotations

import csv
import sys
from collections import defaultdict

import numpy as np
from sentence_transformers import SentenceTransformer

from nlu import INTENT_EXAMPLES, EMBEDDING_MODEL_ID, CONFIDENCE_THRESHOLD

# Candidate confidence thresholds swept by main(), in ascending order.
THRESHOLDS = [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]


def _centroid(vecs, exclude=None):
    """Return the unit-length mean of *vecs*, optionally leaving one row out.

    Args:
        vecs: sequence of 1-D vectors of equal length (e.g. the rows returned
            by ``encoder.encode(..., normalize_embeddings=True)``).
        exclude: index of the vector to drop before averaging (used for
            leave-one-out), or ``None`` to average all of them.

    Returns:
        A 1-D ``np.ndarray`` with unit L2 norm, so ``np.dot(q, centroid)`` is
        the cosine similarity for a unit-length query ``q``. If the mean is
        exactly the zero vector it is returned unnormalised, because it has
        no direction to scale.

    Raises:
        ValueError: if no vectors remain after applying *exclude*. The mean of
            an empty set is undefined and would otherwise produce a NaN
            centroid that silently poisons every score computed against it.
    """
    if exclude is not None:
        vecs = [v for i, v in enumerate(vecs) if i != exclude]
    # len() rather than truthiness: vecs may be an ndarray when exclude is None.
    if len(vecs) == 0:
        raise ValueError("cannot build a centroid from zero vectors")
    c = np.mean(vecs, axis=0)
    norm = np.linalg.norm(c)
    # Avoid 0/0 -> NaN when the vectors cancel out exactly.
    return c / norm if norm > 0 else c


def _best_intent(q, cents):
    """Return ``(intent, score)`` for the centroid in *cents* most similar to *q*.

    Scores are dot products. With a unit-length *q* and unit centroids they
    are cosine similarities. If *cents* is empty, the prediction is
    ``"unknown"`` with score 0.0.
    """
    if not cents:
        return "unknown", 0.0
    scores = {it: float(np.dot(q, c)) for it, c in cents.items()}
    best = max(scores, key=scores.get)
    return best, scores[best]


def _build_records(encoder):
    """Classify every evaluation utterance against the intent centroids.

    The mode is chosen by the command line:

    * ``sys.argv[1]`` given: read that CSV (header row with ``text`` and
      ``intent`` columns). Classify each ``text`` against centroids built
      from all of ``INTENT_EXAMPLES``.
    * otherwise: leave-one-out over ``INTENT_EXAMPLES``. Each phrase is
      removed from its own intent's centroid before it is classified.

    Args:
        encoder: object whose ``encode(texts, normalize_embeddings=True)``
            returns one embedding row per input text.

    Returns:
        list of ``(text, gold_intent, pred_intent, confidence)`` tuples.

    Raises:
        ValueError: if ``INTENT_EXAMPLES`` contains no example phrases at all.
    """
    # Intents with no example phrases have no centroid and are left out.
    emb = {it: encoder.encode(ph, normalize_embeddings=True)
           for it, ph in INTENT_EXAMPLES.items() if len(ph) > 0}
    if not emb:
        raise ValueError("INTENT_EXAMPLES has no example phrases")
    # Full centroids are computed once and shared by both modes.
    full_cents = {it: _centroid(list(v)) for it, v in emb.items()}

    if len(sys.argv) > 1:
        # utf-8-sig also strips a leading BOM (common in spreadsheet exports).
        # Otherwise the BOM would corrupt the first header name.
        # csv expects newline="".
        with open(sys.argv[1], encoding="utf-8-sig", newline="") as fh:
            rows = list(csv.DictReader(fh))
        if not rows:
            return []
        # One batched encode call instead of one call per row.
        queries = encoder.encode([r["text"] for r in rows], normalize_embeddings=True)
        records = []
        for r, q in zip(rows, queries):
            best, conf = _best_intent(q, full_cents)
            records.append((r["text"], r["intent"], best, conf))
        return records

    # leave-one-out over the examples themselves
    records = []
    for it, vecs in emb.items():
        vecs = list(vecs)
        for i, q in enumerate(vecs):
            # Only the held-out intent's centroid changes; reuse the rest.
            cents = dict(full_cents)
            if len(vecs) > 1:
                cents[it] = _centroid(vecs, exclude=i)
            else:
                # A single-example intent has nothing left once its only
                # phrase is held out. Drop it rather than score against a
                # NaN centroid.
                del cents[it]
            best, conf = _best_intent(q, cents)
            records.append((INTENT_EXAMPLES[it][i], it, best, conf))
    return records


def main():
    """Run the evaluation and print a threshold sweep plus per-intent recall.

    Uses ``sys.argv[1]`` as an external labeled CSV when given. Otherwise it
    runs leave-one-out over ``INTENT_EXAMPLES``. A prediction whose confidence
    is below the threshold is replaced by ``'unknown'`` before it is compared
    with the gold label. Accuracy is over all utterances; coverage is the
    fraction not gated to ``'unknown'``.
    """
    print(f"Loading {EMBEDDING_MODEL_ID} …")
    encoder = SentenceTransformer(EMBEDDING_MODEL_ID, device="cpu")
    records = _build_records(encoder)
    n = len(records)
    mode = "external CSV" if len(sys.argv) > 1 else "leave-one-out"
    if n == 0:
        # Nothing to score; the ratios below would divide by zero.
        print(f"No utterances to evaluate ({mode}).")
        return
    print(f"Evaluating {n} utterances ({mode}).\n")

    def is_current(th):
        return abs(th - CONFIDENCE_THRESHOLD) < 1e-9

    # Always sweep the configured threshold. Otherwise the "<- current" row
    # would be missing when it is not one of the preset values.
    sweep = list(THRESHOLDS)
    if not any(is_current(th) for th in sweep):
        sweep.append(CONFIDENCE_THRESHOLD)
    sweep.sort()

    print("threshold   accuracy   coverage   (below thresh → predicted 'unknown')")
    for th in sweep:
        correct = cov = 0
        for _, gold, pred, conf in records:
            p = pred if conf >= th else "unknown"
            if p != "unknown":
                cov += 1
            if p == gold:
                correct += 1
        marker = "  <- current" if is_current(th) else ""
        print(f"  {th:.2f}       {correct/n:.3f}      {cov/n:.3f}{marker}")

    th = CONFIDENCE_THRESHOLD
    # confusion[gold][predicted] -> count, after threshold gating.
    confusion = defaultdict(lambda: defaultdict(int))
    for _, gold, pred, conf in records:
        p = pred if conf >= th else "unknown"
        confusion[gold][p] += 1

    print(f"\nPer-intent recall @ {th:.2f}:")
    for gold in sorted(confusion):
        # total > 0: a gold key only exists after at least one increment.
        total = sum(confusion[gold].values())
        hit = confusion[gold][gold]
        # Highest-count wrong prediction (ties broken by label, descending).
        worst = sorted(((c, p) for p, c in confusion[gold].items() if p != gold), reverse=True)
        leak = f"  (most confused → {worst[0][1]}×{worst[0][0]})" if worst else ""
        print(f"  {gold:14s} {hit}/{total}  recall={hit/total:.2f}{leak}")


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()