from __future__ import annotations import argparse from collections import Counter, defaultdict from dataclasses import dataclass from datasets import load_dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline @dataclass(frozen=True) class Config: dataset_id: str model_id_or_path: str subfolder: str | None split: str max_samples: int | None def _build_text(subject: str, body: str) -> str: subject = "" if subject is None else str(subject) body = "" if body is None else str(body) if subject and body: return f"Subject: {subject}\n\nBody: {body}" return subject or body def _parse_args() -> Config: p = argparse.ArgumentParser(description="Evaluate a model (local or Hub) against a HF dataset split.") p.add_argument("--dataset-id", default="FlowRank/labeled_emails") p.add_argument("--model", default="outputs", help="Local path OR Hugging Face repo id (e.g. FlowRank/mailSort).") p.add_argument("--subfolder", default=None, help="Optional subfolder (e.g. model).") p.add_argument("--split", default="test", help="Which split to evaluate (e.g. test).") p.add_argument("--max-samples", type=int, default=None, help="Limit evaluation to N samples.") a = p.parse_args() return Config( dataset_id=a.dataset_id, model_id_or_path=a.model, subfolder=a.subfolder, split=a.split, max_samples=a.max_samples, ) def main() -> int: cfg = _parse_args() ds = load_dataset(cfg.dataset_id) if cfg.split not in ds: raise SystemExit(f"Split '{cfg.split}' not found. Available: {list(ds.keys())}") rows = ds[cfg.split] if cfg.max_samples is not None: rows = rows.select(range(min(cfg.max_samples, len(rows)))) kwargs = {} if cfg.subfolder: kwargs["subfolder"] = cfg.subfolder tokenizer = AutoTokenizer.from_pretrained(cfg.model_id_or_path, **kwargs) model = AutoModelForSequenceClassification.from_pretrained(cfg.model_id_or_path, **kwargs) clf = pipeline( "text-classification", model=model, tokenizer=tokenizer, truncation=True, ) correct = 0 total = 0 per_label = Counter() per_label_ok = Counter() confusion = defaultdict(Counter) # true -> pred -> count for ex in rows: text = _build_text(ex.get("subject"), ex.get("body")) true_label = str(ex["label"]) pred = clf(text, top_k=1)[0]["label"] total += 1 per_label[true_label] += 1 confusion[true_label][pred] += 1 if pred == true_label: correct += 1 per_label_ok[true_label] += 1 acc = correct / total if total else 0.0 print(f"dataset={cfg.dataset_id} split={cfg.split} samples={total}") print(f"accuracy={acc:.4f} ({correct}/{total})") print("\nper-label accuracy:") for label in sorted(per_label.keys()): denom = per_label[label] num = per_label_ok[label] print(f"- {label}: {num}/{denom} = {num/denom:.4f}") # print top confusions per label (lightweight) print("\ncommon confusions (top-2 per true label):") for true_label in sorted(confusion.keys()): most = confusion[true_label].most_common(3) # skip perfect-only rows if len(most) == 1 and most[0][0] == true_label: continue top = ", ".join([f"{pred}:{cnt}" for pred, cnt in most]) print(f"- {true_label}: {top}") return 0 if __name__ == "__main__": raise SystemExit(main())