# File size: 3,573 Bytes
# 8153a62
from __future__ import annotations

import argparse
from collections import Counter, defaultdict
from dataclasses import dataclass

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


@dataclass(frozen=True)
class Config:
    dataset_id: str
    model_id_or_path: str
    subfolder: str | None
    split: str
    max_samples: int | None


def _build_text(subject: str, body: str) -> str:
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body


def _parse_args() -> Config:
    """Read command-line flags from ``sys.argv`` and package them as a Config.

    ``--help`` and invalid flags are handled by argparse, which exits the process.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a model (local or Hub) against a HF dataset split."
    )
    # (flag, add_argument keyword options), registered in a fixed order.
    option_specs = (
        ("--dataset-id", {"default": "FlowRank/labeled_emails"}),
        (
            "--model",
            {
                "default": "outputs",
                "help": "Local path OR Hugging Face repo id (e.g. FlowRank/mailSort).",
            },
        ),
        ("--subfolder", {"default": None, "help": "Optional subfolder (e.g. model)."}),
        ("--split", {"default": "test", "help": "Which split to evaluate (e.g. test)."}),
        (
            "--max-samples",
            {"type": int, "default": None, "help": "Limit evaluation to N samples."},
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    ns = parser.parse_args()
    return Config(
        dataset_id=ns.dataset_id,
        model_id_or_path=ns.model,  # the flag is "--model"; the field name is longer
        subfolder=ns.subfolder,
        split=ns.split,
        max_samples=ns.max_samples,
    )


def _top_confusions(pred_counts: Counter, true_label: str, k: int = 2) -> list[tuple[str, int]]:
    """Return up to ``k`` most frequent *incorrect* predictions for one true label.

    Args:
        pred_counts: Counter mapping predicted label -> count, for one true label.
        true_label: The gold label. Predictions equal to it are correct and are excluded.
        k: Maximum number of confusions to return.

    Returns:
        ``(pred, count)`` pairs in descending count order. Empty when every
        prediction for ``true_label`` was correct.
    """
    wrong = Counter({pred: cnt for pred, cnt in pred_counts.items() if pred != true_label})
    return wrong.most_common(k)


def main() -> int:
    """Evaluate the configured model on one dataset split and print a report.

    Steps:
    - Load ``cfg.dataset_id`` with ``load_dataset``.
    - Optionally keep only the first ``cfg.max_samples`` rows of ``cfg.split``.
    - Classify each row's subject/body text with a ``text-classification`` pipeline.
    - Print overall accuracy, per-label accuracy, and the top-2 misclassifications
      for each true label.

    Returns:
        Process exit code (0).

    Raises:
        SystemExit: if ``cfg.split`` is not a split of the loaded dataset.
    """
    cfg = _parse_args()

    ds = load_dataset(cfg.dataset_id)
    if cfg.split not in ds:
        raise SystemExit(f"Split '{cfg.split}' not found. Available: {list(ds.keys())}")

    rows = ds[cfg.split]
    if cfg.max_samples is not None:
        rows = rows.select(range(min(cfg.max_samples, len(rows))))

    # Forward `subfolder` only when it is set, so repo roots / plain local dirs still load.
    kwargs = {"subfolder": cfg.subfolder} if cfg.subfolder else {}

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_id_or_path, **kwargs)
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_id_or_path, **kwargs)
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        truncation=True,
    )

    correct = 0
    total = 0
    per_label = Counter()
    per_label_ok = Counter()
    confusion = defaultdict(Counter)  # true -> pred -> count

    for ex in rows:
        text = _build_text(ex.get("subject"), ex.get("body"))
        # NOTE(review): correctness is judged by string equality between str(label)
        # and the pipeline's predicted label. If the dataset label column is an
        # integer ClassLabel, this gives "0", "1", ..., which must match the
        # model's id2label strings. Verify against the training setup.
        true_label = str(ex["label"])
        # The pipeline returns a list of {"label", "score"} dicts; [0] is the top one.
        pred = clf(text, top_k=1)[0]["label"]

        total += 1
        per_label[true_label] += 1
        confusion[true_label][pred] += 1

        if pred == true_label:
            correct += 1
            per_label_ok[true_label] += 1

    acc = correct / total if total else 0.0
    print(f"dataset={cfg.dataset_id} split={cfg.split} samples={total}")
    print(f"accuracy={acc:.4f} ({correct}/{total})")
    print("\nper-label accuracy:")
    for label in sorted(per_label):
        denom = per_label[label]  # always >= 1: labels only appear after being counted
        num = per_label_ok[label]
        print(f"- {label}: {num}/{denom} = {num/denom:.4f}")

    # Only misclassifications are listed. Labels with no errors are omitted.
    print("\ncommon confusions (top-2 per true label):")
    for true_label in sorted(confusion):
        top = _top_confusions(confusion[true_label], true_label, k=2)
        if not top:
            continue
        rendered = ", ".join(f"{pred}:{cnt}" for pred, cnt in top)
        print(f"- {true_label}: {rendered}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())