File size: 19,018 Bytes
14a5b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
"""
Evaluate all three G.U.I.D.E. models and print train/validation metrics.

Usage:
    # NER + NextActionPredictor only (no CFPB CSV needed)
    python scripts/evaluate_models.py --skip_classifier

    # All three models (run on Kaggle where CFPB CSV is available)
    python scripts/evaluate_models.py --cfpb_csv /kaggle/input/datasets/sharav95/complaint/complaints.csv

Models are downloaded automatically from sarav95/guide-models on HuggingFace
if not already present locally. Set HF_TOKEN env var if needed.
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import sys
from pathlib import Path

import torch

_ROOT = Path(__file__).resolve().parents[1]
_HF_REPO = "sarav95/guide-models"

# Put the repository root first on the import path so `src.*` modules resolve
# when this script is executed directly rather than as a package module.
sys.path.insert(0, str(_ROOT))
logging.basicConfig(level=logging.WARNING)


def _ensure_models() -> None:
    """Download the model snapshot from the HuggingFace Hub if any checkpoint is absent.

    Each of the three checkpoints under ``<repo root>/models`` is detected by one
    marker file. If every marker exists the function returns immediately;
    otherwise the whole ``_HF_REPO`` snapshot is downloaded into that directory,
    authenticating with the ``HF_TOKEN`` environment variable when it is set.
    Exits the process with status 1 if ``huggingface_hub`` is not installed.
    """
    target = _ROOT / "models"
    markers = (
        target / "evidence_ner" / "config.json",
        target / "domain_classifier" / "config.json",
        target / "next_action" / "model.pt",
    )
    if all(marker.exists() for marker in markers):
        return

    print(f"  Model checkpoints missing — downloading from {_HF_REPO!r} …")
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("  [error] huggingface_hub not installed: pip install huggingface_hub")
        sys.exit(1)

    target.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=_HF_REPO,
        local_dir=str(target),
        local_dir_use_symlinks=False,
        token=os.environ.get("HF_TOKEN"),
    )
    print("  Model download complete.")


# ---------------------------------------------------------------------------
# print_summary_table
# ---------------------------------------------------------------------------

def print_summary_table(results: list[dict]) -> None:
    """Render one aligned text table summarising every evaluated model/split.

    Prints nothing at all when ``results`` is empty.

    Args:
        results: rows collected by the evaluators; each dict must provide the
            keys ``model``, ``split``, ``accuracy`` and ``macro_f1``.
    """
    if not results:
        return

    header = ["Model", "Split", "Accuracy", "Macro-F1"]
    body = []
    for entry in results:
        body.append([
            entry["model"],
            entry["split"],
            f"{entry['accuracy']:.4f}",
            f"{entry['macro_f1']:.4f}",
        ])

    # Each column is as wide as its longest cell, header included.
    widths = []
    for col, title in enumerate(header):
        cells = [title, *(line[col] for line in body)]
        widths.append(max(len(str(cell)) for cell in cells))

    template = "  ".join("{:<%d}" % w for w in widths)
    banner = "=" * (sum(widths) + 2 * (len(widths) - 1))

    print(f"\n{banner}")
    print("  Summary — All Models")
    print(banner)
    print(template.format(*header))
    print("  ".join("-" * w for w in widths))
    for line in body:
        print(template.format(*line))
    print()


# ---------------------------------------------------------------------------
# DomainClassifier
# ---------------------------------------------------------------------------

def evaluate_domain_classifier(cfpb_csv: str | None, results: list[dict]) -> None:
    """Evaluate DomainClassifier on train sample and validation set.

    Recreates the exact 90/10 split used during training (seed=42).
    Skips gracefully when cfpb_csv is None.

    Args:
        cfpb_csv: path to CFPB complaints CSV, or None to skip
        results:  shared list to append summary rows to
    """
    print("\n" + "=" * 72)
    print("  DomainClassifier  (DistilBERT, 6-class)")
    print("=" * 72)

    # --- Training curve from Kaggle log (hardcoded) ---
    print("\n  Training curve (from Kaggle log):")
    curve_headers = ["Epoch", "Train loss range", "Val loss", "Notes"]
    curve_rows = [
        ["1", "0.8401 → 0.2807", "0.2768", ""],
        ["2", "0.2460 → 0.1955", "0.2720", "best checkpoint (load_best_model_at_end)"],
        ["3", "0.2129 → 0.1310", "0.3334", "overfitting — epoch 2 weights saved"],
    ]
    col_w = [max(len(str(x)) for x in [h] + [r[i] for r in curve_rows])
             for i, h in enumerate(curve_headers)]
    fmt = "  ".join(f"{{:<{w}}}" for w in col_w)
    sep = "  ".join("-" * w for w in col_w)
    print("  " + fmt.format(*curve_headers))
    print("  " + sep)
    for row in curve_rows:
        print("  " + fmt.format(*row))
    print("  Final train loss: 0.2402  |  train samples/sec: 37.12")

    if cfpb_csv is None:
        print("\n  [skipped] Pass --cfpb_csv <path> to evaluate on data splits.")
        return

    from datasets import concatenate_datasets
    from sklearn.metrics import accuracy_score, classification_report, f1_score
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    from src.classifier.train import _build_supplement, load_and_remap_cfpb
    from src.classifier.model import DOMAIN_LABELS

    print("\n  Loading data …")
    cfpb_ds = load_and_remap_cfpb(cfpb_csv, max_per_class=50_000)
    suppl_ds = _build_supplement(n_per_class=5_000)
    full_ds = concatenate_datasets([cfpb_ds, suppl_ds]).shuffle(seed=42)
    split = full_ds.train_test_split(test_size=0.1, seed=42)

    model_dir = "models/domain_classifier"
    print(f"  Loading checkpoint from {model_dir} …")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    model.to(device)

    def _predict_batch(texts: list[str]) -> list[int]:
        enc = tokenizer(texts, truncation=True, max_length=512,
                        padding=True, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            logits = model(**enc).logits
        return logits.argmax(dim=-1).cpu().tolist()

    def _eval_split(ds, name: str, max_samples: int) -> None:
        if len(ds) > max_samples:
            ds = ds.select(range(max_samples))
        texts, labels = ds["text"], ds["labels"]
        preds: list[int] = []
        for i in range(0, len(texts), 64):
            preds.extend(_predict_batch(texts[i:i + 64]))
        acc = accuracy_score(labels, preds)
        mac_f1 = f1_score(labels, preds, average="macro", zero_division=0)
        print(f"\n  [{name}]  n={len(ds)}  accuracy={acc:.4f}  macro-F1={mac_f1:.4f}")
        report = classification_report(labels, preds,
                                       target_names=DOMAIN_LABELS, zero_division=0)
        for line in report.splitlines():
            print(f"    {line}")
        results.append({"model": "DomainClassifier", "split": name,
                        "accuracy": acc, "macro_f1": mac_f1})

    _eval_split(split["train"], "train", max_samples=5_000)
    _eval_split(split["test"],  "validation", max_samples=len(split["test"]))


# ---------------------------------------------------------------------------
# EvidenceNER
# ---------------------------------------------------------------------------

def _words_to_bio(sentence: str, entities: list[dict]) -> list[str]:
    """Convert a sentence + entity list to a BIO tag sequence over whitespace tokens.

    Args:
        sentence: raw complaint sentence string
        entities: list of {"text": str, "label": str} dicts

    Returns:
        list of BIO label strings aligned to sentence.split()
    """
    words = sentence.split()
    tags = ["O"] * len(words)
    for ent in entities:
        ent_words = ent["text"].split()
        label = ent["label"]
        # slide a window to find where entity words appear in sentence words
        for i in range(len(words) - len(ent_words) + 1):
            if words[i:i + len(ent_words)] == ent_words:
                tags[i] = f"B-{label}"
                for j in range(1, len(ent_words)):
                    tags[i + j] = f"I-{label}"
                break
    return tags


def _predict_bio_tags(sentence: str, model, tokenizer, id2label: dict,
                      device: torch.device) -> list[str]:
    """Run NER model on a single sentence and return word-level BIO tags.

    Args:
        sentence: raw string to tag
        model:    loaded token classification model
        tokenizer: matching tokenizer
        id2label:  id→BIO label mapping
        device:    torch device

    Returns:
        list of BIO label strings, one per whitespace token
    """
    words = sentence.split()
    enc = tokenizer(words, truncation=True, max_length=512,
                    is_split_into_words=True, return_tensors="pt")
    word_ids = tokenizer(words, truncation=True, max_length=512,
                         is_split_into_words=True).word_ids()
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        logits = model(**enc).logits[0]
    pred_ids = logits.argmax(dim=-1).cpu().tolist()

    # First subword per word gets the predicted tag
    pred_tags: list[str] = []
    prev_word_id = None
    for tok_idx, word_id in enumerate(word_ids):
        if word_id is None or word_id == prev_word_id:
            prev_word_id = word_id
            continue
        prev_word_id = word_id
        pred_tags.append(id2label[pred_ids[tok_idx]])
    return pred_tags[:len(words)]


def evaluate_ner_synthetic(model, tokenizer, id2label: dict,
                           device: torch.device, results: list[dict]) -> None:
    """Evaluate EvidenceNER on synthetic train and validation splits.

    Recreates the 90/10 split from build_synthetic_dataset (seed=42). When
    ``_try_load_conll()`` returns a dataset, it is concatenated and shuffled in
    before splitting, so the "synthetic" splits then also contain CoNLL
    examples. At most 2,000 train examples are scored; validation is scored in
    full. Prints entity-level accuracy, precision, recall, micro-F1 and
    macro-F1. The macro-F1 is the value recorded for the summary table.

    Args:
        model:    loaded token classification model
        tokenizer: matching tokenizer
        id2label:  id→BIO label mapping
        device:   torch device
        results:  shared list to append summary rows to
    """
    try:
        from seqeval.metrics import (
            accuracy_score, classification_report,
            f1_score, precision_score, recall_score,
        )
    except ImportError:
        print("  [error] seqeval not installed: pip install seqeval")
        return

    from datasets import concatenate_datasets
    from src.ner.train import _try_load_conll, build_synthetic_dataset

    print("\n  Building synthetic dataset …")
    synthetic_ds = build_synthetic_dataset(n_samples=4000)
    conll_ds = _try_load_conll()
    if conll_ds is not None:
        full_ds = concatenate_datasets([synthetic_ds, conll_ds]).shuffle(seed=42)
    else:
        full_ds = synthetic_ds
    split = full_ds.train_test_split(test_size=0.1, seed=42)

    def _eval_split(ds, name: str, max_samples: int) -> None:
        if len(ds) > max_samples:
            ds = ds.select(range(max_samples))
        if len(ds) == 0:
            # seqeval divides by the token count; nothing meaningful to report.
            print(f"\n  [synthetic {name}]  n=0 — nothing to evaluate")
            return
        true_seqs, pred_seqs = [], []
        for ex in ds:
            true_tags = [id2label[t] for t in ex["ner_tags"]]
            sentence = " ".join(ex["words"])
            pred_tags = _predict_bio_tags(sentence, model, tokenizer,
                                          id2label, device)
            # Predictions can be shorter than the gold tags when the tokenizer
            # truncates; only the overlapping prefix is compared.
            n = min(len(true_tags), len(pred_tags))
            true_seqs.append(true_tags[:n])
            pred_seqs.append(pred_tags[:n])

        acc = accuracy_score(true_seqs, pred_seqs)
        prec = precision_score(true_seqs, pred_seqs, zero_division=0)
        rec = recall_score(true_seqs, pred_seqs, zero_division=0)
        # seqeval averages "micro" by default; the summary table's column is
        # Macro-F1, so compute that explicitly and report both.
        micro_f1 = f1_score(true_seqs, pred_seqs, zero_division=0)
        macro_f1 = f1_score(true_seqs, pred_seqs, average="macro", zero_division=0)
        print(f"\n  [synthetic {name}]  n={len(ds)}")
        print(f"    accuracy={acc:.4f}  precision={prec:.4f}  "
              f"recall={rec:.4f}  micro-F1={micro_f1:.4f}  macro-F1={macro_f1:.4f}")
        report = classification_report(true_seqs, pred_seqs, zero_division=0)
        for line in report.splitlines():
            print(f"    {line}")
        results.append({"model": "EvidenceNER (synthetic)",
                        "split": name, "accuracy": acc, "macro_f1": macro_f1})

    _eval_split(split["train"], "train", max_samples=2_000)
    _eval_split(split["test"],  "validation", max_samples=len(split["test"]))


def evaluate_ner_real(model, tokenizer, id2label: dict,
                      device: torch.device, results: list[dict]) -> None:
    """Evaluate EvidenceNER on hand-verified real complaint sentences.

    Loads ``data/eval/ner_real_complaints.json``, resolved against the repo
    root. The file is a JSON list of ``{"sentence": str, "entities": [{"text",
    "label"}]}`` items. Gold entities are converted to word-level BIO tags with
    ``_words_to_bio`` and scored against the model's predictions using seqeval.
    Skips gracefully if the file is missing or empty. The macro-F1 is the value
    recorded for the summary table.

    Args:
        model:    loaded token classification model
        tokenizer: matching tokenizer
        id2label:  id→BIO label mapping
        device:   torch device
        results:  shared list to append summary rows to
    """
    # Resolve against the repo root so the lookup does not depend on the CWD.
    dataset_path = _ROOT / "data" / "eval" / "ner_real_complaints.json"
    if not dataset_path.exists():
        print(f"\n  [skipped] {dataset_path} not found — real complaint eval skipped.")
        return

    try:
        from seqeval.metrics import (
            accuracy_score, classification_report,
            f1_score, precision_score, recall_score,
        )
    except ImportError:
        print("  [error] seqeval not installed: pip install seqeval")
        return

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # fail on non-ASCII characters in the complaint text.
    with dataset_path.open(encoding="utf-8") as f:
        dataset = json.load(f)
    if not dataset:
        # seqeval divides by the token count; an empty file would crash it.
        print(f"\n  [skipped] {dataset_path} is empty — real complaint eval skipped.")
        return

    true_seqs, pred_seqs = [], []
    for item in dataset:
        sentence = item["sentence"]
        entities = item["entities"]
        true_tags = _words_to_bio(sentence, entities)
        pred_tags = _predict_bio_tags(sentence, model, tokenizer, id2label, device)
        # Predictions can be shorter than the gold tags when the tokenizer
        # truncates; only the overlapping prefix is compared.
        n = min(len(true_tags), len(pred_tags))
        true_seqs.append(true_tags[:n])
        pred_seqs.append(pred_tags[:n])

    acc = accuracy_score(true_seqs, pred_seqs)
    prec = precision_score(true_seqs, pred_seqs, zero_division=0)
    rec = recall_score(true_seqs, pred_seqs, zero_division=0)
    # seqeval averages "micro" by default; the summary table's column is
    # Macro-F1, so compute that explicitly and report both.
    micro_f1 = f1_score(true_seqs, pred_seqs, zero_division=0)
    macro_f1 = f1_score(true_seqs, pred_seqs, average="macro", zero_division=0)
    print(f"\n  [real complaints]  n={len(dataset)}")
    print(f"    accuracy={acc:.4f}  precision={prec:.4f}  "
          f"recall={rec:.4f}  micro-F1={micro_f1:.4f}  macro-F1={macro_f1:.4f}")
    report = classification_report(true_seqs, pred_seqs, zero_division=0)
    for line in report.splitlines():
        print(f"    {line}")
    results.append({"model": "EvidenceNER (real)", "split": "validation",
                    "accuracy": acc, "macro_f1": macro_f1})


def evaluate_ner(results: list[dict]) -> None:
    """Load EvidenceNER checkpoint and run synthetic + real complaint evaluation.

    The checkpoint is read from ``<repo root>/models/evidence_ner``, the same
    location ``_ensure_models`` downloads to. It is moved to CUDA, MPS or CPU,
    in that order of preference.

    Args:
        results: shared list to append summary rows to
    """
    print("\n" + "=" * 72)
    print("  EvidenceNER  (DistilBERT token classifier, BIO 13-label)")
    print("=" * 72)

    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from src.ner.model import ID2LABEL

    # Resolve against the repo root so the script works from any CWD.
    model_dir = str(_ROOT / "models" / "evidence_ner")
    print(f"  Loading checkpoint from {model_dir} …")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForTokenClassification.from_pretrained(model_dir)
    model.eval()
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    model.to(device)

    evaluate_ner_synthetic(model, tokenizer, ID2LABEL, device, results)
    evaluate_ner_real(model, tokenizer, ID2LABEL, device, results)


# ---------------------------------------------------------------------------
# NextActionPredictor
# ---------------------------------------------------------------------------

def evaluate_next_action(results: list[dict]) -> None:
    """Evaluate NextActionPredictor on train (90%) and validation (10%) splits.

    Recreates the 6000-sample synthetic dataset (seed=42) and carves a 90/10
    split by position: the first 90% is train and the rest is validation.
    Loads the checkpoint from ``<repo root>/models/next_action/model.pt``.
    Documents legal F1 = 0.00 as a known class-imbalance limitation.

    Args:
        results: shared list to append summary rows to
    """
    print("\n" + "=" * 72)
    print("  NextActionPredictor  (MLP 12→64→64→6)")
    print("=" * 72)

    from sklearn.metrics import accuracy_score, classification_report, f1_score

    from src.next_action.train import build_synthetic_dataset
    from src.next_action.model import ACTION_LABELS, GUIDE_MLP

    print("  Building synthetic dataset (n=6000, seed=42) …")
    X_list, y_list = build_synthetic_dataset(n_samples=6000, seed=42)
    X_all = torch.tensor(X_list, dtype=torch.float32)
    y_all = torch.tensor(y_list, dtype=torch.long)

    split_idx = int(len(X_all) * 0.9)
    X_train, X_val = X_all[:split_idx], X_all[split_idx:]
    y_train, y_val = y_all[:split_idx], y_all[split_idx:]

    # Resolve against the repo root, which is where _ensure_models() downloads.
    model_path = _ROOT / "models" / "next_action" / "model.pt"
    print(f"  Loading checkpoint from {model_path} …")
    ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
    mlp = GUIDE_MLP()
    mlp.load_state_dict(ckpt["state_dict"])
    mlp.eval()

    # Pin the per-class report to every action. The rare 'legal' class may be
    # absent from both truth and predictions in a split, and without `labels`
    # classification_report then raises on the target_names length mismatch.
    label_ids = list(range(len(ACTION_LABELS)))

    def _eval_split(X: torch.Tensor, y: torch.Tensor, name: str) -> None:
        with torch.no_grad():
            preds = mlp(X).argmax(dim=-1).numpy()
        truths = y.numpy()
        acc = accuracy_score(truths, preds)
        mac_f1 = f1_score(truths, preds, average="macro", zero_division=0)
        print(f"\n  [{name}]  n={len(y)}  accuracy={acc:.4f}  macro-F1={mac_f1:.4f}")
        report = classification_report(truths, preds, labels=label_ids,
                                       target_names=ACTION_LABELS, zero_division=0)
        for line in report.splitlines():
            print(f"    {line}")
        results.append({"model": "NextActionPredictor", "split": name,
                        "accuracy": acc, "macro_f1": mac_f1})

    _eval_split(X_train, y_train, "train")
    _eval_split(X_val, y_val, "validation")

    print("\n  NOTE: 'legal' class F1 = 0.00 is a known limitation.")
    print("  Cause: ~2.5% class frequency due to 20% coin-flip in label")
    print("  assignment. Model learns to never predict 'legal' to maximise")
    print("  overall accuracy. Fix: remove the coin-flip condition in train.py.")


# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------

def main() -> None:
    """CLI entry point: parse flags, fetch missing checkpoints, evaluate, summarise.

    Every evaluator appends its rows to one shared list, and that list is
    printed as a single summary table at the end.
    """
    parser = argparse.ArgumentParser(description="Evaluate G.U.I.D.E. models")
    parser.add_argument("--cfpb_csv", default=None,
                        help="Path to CFPB complaints CSV (required for DomainClassifier)")
    skip_flags = (
        ("--skip_classifier", "Skip DomainClassifier evaluation"),
        ("--skip_ner", "Skip EvidenceNER evaluation"),
        ("--skip_next_action", "Skip NextActionPredictor evaluation"),
    )
    for flag, help_text in skip_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    opts = parser.parse_args()

    _ensure_models()

    collected: list[dict] = []
    if not opts.skip_classifier:
        evaluate_domain_classifier(opts.cfpb_csv, collected)
    if not opts.skip_ner:
        evaluate_ner(collected)
    if not opts.skip_next_action:
        evaluate_next_action(collected)

    print_summary_table(collected)


if __name__ == "__main__":
    main()