# mailSort: src/mailsort/eval.py
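"""Evaluate a mailSort classifier (local checkpoint or Hub repo) on a dataset split.

Loads a sequence-classification model, classifies every example of the chosen
split, and prints overall accuracy, per-label accuracy, and a lightweight
confusion summary.

Example invocation (run from the repo root; the repo id mirrors the one in the
--model help text):

    python src/mailsort/eval.py --model FlowRank/mailSort --split test
"""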
from __future__ import annotations
import argparse
from collections import Counter, defaultdict
from dataclasses import dataclass
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

@dataclass(frozen=True)
class Config:
    dataset_id: str
    model_id_or_path: str
    subfolder: str | None
    split: str
    max_samples: int | None

def _build_text(subject: str, body: str) -> str:
    """Join subject and body into a single classifier input string."""
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body

def _parse_args() -> Config:
    """Parse CLI arguments into an immutable Config."""
    p = argparse.ArgumentParser(description="Evaluate a model (local or Hub) against a HF dataset split.")
    p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
    p.add_argument("--model", default="outputs", help="Local path OR Hugging Face repo id (e.g. FlowRank/mailSort).")
    p.add_argument("--subfolder", default=None, help="Optional subfolder (e.g. model).")
    p.add_argument("--split", default="test", help="Which split to evaluate (e.g. test).")
    p.add_argument("--max-samples", type=int, default=None, help="Limit evaluation to N samples.")
    a = p.parse_args()
    return Config(
        dataset_id=a.dataset_id,
        model_id_or_path=a.model,
        subfolder=a.subfolder,
        split=a.split,
        max_samples=a.max_samples,
    )

def main() -> int:
    """Run the evaluation loop and print metrics; returns a process exit code."""
    cfg = _parse_args()
    ds = load_dataset(cfg.dataset_id)
    if cfg.split not in ds:
        raise SystemExit(f"Split '{cfg.split}' not found. Available: {list(ds.keys())}")
    rows = ds[cfg.split]
    if cfg.max_samples is not None:
        rows = rows.select(range(min(cfg.max_samples, len(rows))))
    # Forward --subfolder only when provided, so plain repo ids and local paths keep working.
    kwargs = {}
    if cfg.subfolder:
        kwargs["subfolder"] = cfg.subfolder
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_id_or_path, **kwargs)
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_id_or_path, **kwargs)
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        truncation=True,
    )
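    # Note: the pipeline runs on CPU by default and classifies one example per
    # call; a device argument (e.g. device=0 for the first GPU) can be passed
    # to pipeline(...) above to speed things up.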
    correct = 0
    total = 0
    per_label = Counter()
    per_label_ok = Counter()
    confusion = defaultdict(Counter)  # true -> pred -> count
    for ex in rows:
        text = _build_text(ex.get("subject"), ex.get("body"))
        # Assumes dataset labels are strings that match the model's id2label names.
        true_label = str(ex["label"])
        pred = clf(text, top_k=1)[0]["label"]
        total += 1
        per_label[true_label] += 1
        confusion[true_label][pred] += 1
        if pred == true_label:
            correct += 1
            per_label_ok[true_label] += 1
    acc = correct / total if total else 0.0
    print(f"dataset={cfg.dataset_id} split={cfg.split} samples={total}")
    print(f"accuracy={acc:.4f} ({correct}/{total})")
    print("\nper-label accuracy:")
    for label in sorted(per_label.keys()):
        denom = per_label[label]
        num = per_label_ok[label]
        print(f"- {label}: {num}/{denom} = {num/denom:.4f}")
    # Print the most frequent predictions per true label (lightweight confusion summary).
    print("\ncommon confusions (top-3 per true label):")
    for true_label in sorted(confusion.keys()):
        most = confusion[true_label].most_common(3)
        # Skip labels that were only ever predicted correctly.
        if len(most) == 1 and most[0][0] == true_label:
            continue
        top = ", ".join(f"{pred}:{cnt}" for pred, cnt in most)
        print(f"- {true_label}: {top}")
    return 0

if __name__ == "__main__":
    raise SystemExit(main())
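
# Usage sketch for a local run (flags mirror _parse_args; the paths are the
# defaults suggested in the help text, not guaranteed to exist):
#
#   python src/mailsort/eval.py --model outputs --subfolder model --max-samples 100
#
# This evaluates the checkpoint saved under outputs/model on the first 100
# examples of the test split.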