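"""Minimal train/test pipeline for the mailSort email classifier.

Fine-tunes a sequence-classification model on a labeled email dataset and,
when a Hub token is present in the environment, pushes the resulting
artifacts to the Hugging Face Hub.
"""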
from __future__ import annotations
import argparse
import os
from dataclasses import dataclass
import numpy as np
from datasets import DatasetDict, load_dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
)
@dataclass(frozen=True)
class Config:
dataset_id: str
model_name: str
hub_model_id: str
output_dir: str
max_length: int
num_train_epochs: float
per_device_train_batch_size: int
per_device_eval_batch_size: int
learning_rate: float
weight_decay: float
seed: int
def _build_text(subject: str, body: str) -> str:
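    """Join subject and body into one input string; either part may be None or empty."""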
subject = "" if subject is None else str(subject)
body = "" if body is None else str(body)
if subject and body:
return f"Subject: {subject}\n\nBody: {body}"
return subject or body
def _parse_args() -> Config:
    p = argparse.ArgumentParser(description="Train an email classifier and push it to the Hugging Face Hub.")
p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
p.add_argument("--model-name", default="distilbert-base-uncased")
p.add_argument("--hub-model-id", default="FlowRank/mailSort")
p.add_argument("--output-dir", default="outputs")
p.add_argument("--max-length", type=int, default=256)
p.add_argument("--num-train-epochs", type=float, default=2)
p.add_argument("--per-device-train-batch-size", type=int, default=16)
p.add_argument("--per-device-eval-batch-size", type=int, default=32)
p.add_argument("--learning-rate", type=float, default=2e-5)
p.add_argument("--weight-decay", type=float, default=0.01)
p.add_argument("--seed", type=int, default=42)
a = p.parse_args()
return Config(
dataset_id=a.dataset_id,
model_name=a.model_name,
hub_model_id=a.hub_model_id,
output_dir=a.output_dir,
max_length=a.max_length,
num_train_epochs=a.num_train_epochs,
per_device_train_batch_size=a.per_device_train_batch_size,
per_device_eval_batch_size=a.per_device_eval_batch_size,
learning_rate=a.learning_rate,
weight_decay=a.weight_decay,
seed=a.seed,
)
def _load_ds(dataset_id: str, seed: int) -> DatasetDict:
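    """Load the dataset and guarantee a train + held-out split where possible."""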
    ds = load_dataset(dataset_id)
    if "train" in ds and ("test" in ds or "validation" in ds):
        return ds  # already split; keep every provided split
    if "train" in ds:
        # fallback: carve a held-out split when only "train" exists
        return ds["train"].train_test_split(test_size=0.1, seed=seed)
    # unexpected structure: return as-is and let main() fail loudly
    return ds
def _prepare(ds: DatasetDict, tokenizer: AutoTokenizer, label2id: dict[str, int], max_length: int) -> DatasetDict:
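    """Tokenize subject/body text and map string labels to integer ids."""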
def preprocess(ex):
text = _build_text(ex.get("subject"), ex.get("body"))
out = tokenizer(text, truncation=True, max_length=max_length)
out["labels"] = label2id[str(ex["label"])]
return out
    any_split = next(iter(ds.values()))
    cols_to_remove = [c for c in ("subject", "body", "label") if c in any_split.column_names]
return ds.map(preprocess, remove_columns=cols_to_remove)
def _compute_metrics(eval_pred):
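    """Return top-1 accuracy from a (logits, labels) EvalPrediction pair."""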
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
acc = (preds == labels).astype(np.float32).mean().item()
return {"accuracy": acc}
def main() -> int:
cfg = _parse_args()
ds = _load_ds(cfg.dataset_id, seed=cfg.seed)
train_split = "train" if "train" in ds else list(ds.keys())[0]
test_split = "test" if "test" in ds else ("validation" if "validation" in ds else None)
if test_split is None:
raise SystemExit(f"Dataset must have a test/validation split. Found: {list(ds.keys())}")
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
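    # derive the label set from the training split; sorting keeps the id mapping deterministic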
labels = sorted({str(x) for x in ds[train_split]["label"]})
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
encoded = _prepare(ds, tokenizer, label2id=label2id, max_length=cfg.max_length)
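    # pad dynamically per batch rather than padding every example to max_length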
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
model = AutoModelForSequenceClassification.from_pretrained(
cfg.model_name,
num_labels=len(labels),
label2id=label2id,
id2label=id2label,
)
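    # push to the Hub only when an auth token is available in the environment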
push_to_hub = bool(os.getenv("HF_TOKEN")) or bool(os.getenv("HUGGINGFACE_HUB_TOKEN"))
args = TrainingArguments(
output_dir=cfg.output_dir,
num_train_epochs=cfg.num_train_epochs,
learning_rate=cfg.learning_rate,
per_device_train_batch_size=cfg.per_device_train_batch_size,
per_device_eval_batch_size=cfg.per_device_eval_batch_size,
weight_decay=cfg.weight_decay,
eval_strategy="epoch",
save_strategy="epoch",
logging_strategy="steps",
logging_steps=50,
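        # reload the checkpoint with the best eval accuracy when training ends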
load_best_model_at_end=True,
metric_for_best_model="accuracy",
seed=cfg.seed,
report_to="none",
push_to_hub=push_to_hub,
hub_model_id=cfg.hub_model_id if push_to_hub else None,
        hub_strategy="end",  # ignored unless push_to_hub is True
)
trainer = Trainer(
model=model,
args=args,
train_dataset=encoded[train_split],
eval_dataset=encoded[test_split],
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=_compute_metrics,
)
trainer.train()
trainer.evaluate()
trainer.save_model(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
if args.push_to_hub:
trainer.push_to_hub()
return 0
if __name__ == "__main__":
raise SystemExit(main())
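# Example invocation, assuming the package is installed (src layout) and that
# HF_TOKEN is exported if you want the final Hub push:
#   python -m mailsort.train --dataset-id FlowRank/labeled_emails --num-train-epochs 2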