from __future__ import annotations

import argparse
import os
from dataclasses import dataclass

import numpy as np
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)


@dataclass(frozen=True)
class Config:
    dataset_id: str
    model_name: str
    hub_model_id: str
    output_dir: str
    max_length: int
    num_train_epochs: float
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    learning_rate: float
    weight_decay: float
    seed: int


def _build_text(subject: str | None, body: str | None) -> str:
    """Join subject and body into one input string, tolerating missing fields."""
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body


def _parse_args() -> Config:
    p = argparse.ArgumentParser(description="Train & push email classifier to Hugging Face Hub.")
    p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
    p.add_argument("--model-name", default="distilbert-base-uncased")
    p.add_argument("--hub-model-id", default="FlowRank/mailSort")
    p.add_argument("--output-dir", default="outputs")
    p.add_argument("--max-length", type=int, default=256)
    p.add_argument("--num-train-epochs", type=float, default=2)
    p.add_argument("--per-device-train-batch-size", type=int, default=16)
    p.add_argument("--per-device-eval-batch-size", type=int, default=32)
    p.add_argument("--learning-rate", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=42)
    a = p.parse_args()
    return Config(
        dataset_id=a.dataset_id,
        model_name=a.model_name,
        hub_model_id=a.hub_model_id,
        output_dir=a.output_dir,
        max_length=a.max_length,
        num_train_epochs=a.num_train_epochs,
        per_device_train_batch_size=a.per_device_train_batch_size,
        per_device_eval_batch_size=a.per_device_eval_batch_size,
        learning_rate=a.learning_rate,
        weight_decay=a.weight_decay,
        seed=a.seed,
    )


def _load_ds(dataset_id: str, seed: int) -> DatasetDict:
    ds = load_dataset(dataset_id)
    if "train" in ds and "test" in ds:
        return ds  # already split
    # Fallback: carve a test split out of "train" when it is the only split.
    if "train" in ds and "test" not in ds:
        return ds["train"].train_test_split(test_size=0.1, seed=seed)
    # Unexpected structure: return as-is and let the caller fail loudly.
    return ds


def _prepare(
    ds: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    label2id: dict[str, int],
    max_length: int,
) -> DatasetDict:
    def preprocess(ex):
        text = _build_text(ex.get("subject"), ex.get("body"))
        out = tokenizer(text, truncation=True, max_length=max_length)
        out["labels"] = label2id[str(ex["label"])]
        return out

    cols_to_remove = [c for c in ("subject", "body", "label") if c in ds["train"].column_names]
    return ds.map(preprocess, remove_columns=cols_to_remove)


def _compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = (preds == labels).astype(np.float32).mean().item()
    return {"accuracy": acc}


def main() -> int:
    cfg = _parse_args()
    ds = _load_ds(cfg.dataset_id, seed=cfg.seed)

    train_split = "train" if "train" in ds else list(ds.keys())[0]
    test_split = "test" if "test" in ds else ("validation" if "validation" in ds else None)
    if test_split is None:
        raise SystemExit(f"Dataset must have a test/validation split. Found: {list(ds.keys())}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)

    # Derive the label vocabulary from the training split so the mapping is stable.
    labels = sorted({str(x) for x in ds[train_split]["label"]})
    label2id = {l: i for i, l in enumerate(labels)}
    id2label = {i: l for l, i in label2id.items()}

    encoded = _prepare(ds, tokenizer, label2id=label2id, max_length=cfg.max_length)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(
        cfg.model_name,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
    )

    # Only push to the Hub when an auth token is available in the environment.
    push_to_hub = bool(os.getenv("HF_TOKEN")) or bool(os.getenv("HUGGINGFACE_HUB_TOKEN"))

    args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        learning_rate=cfg.learning_rate,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        per_device_eval_batch_size=cfg.per_device_eval_batch_size,
        weight_decay=cfg.weight_decay,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        seed=cfg.seed,
        report_to="none",
        push_to_hub=push_to_hub,
        hub_model_id=cfg.hub_model_id if push_to_hub else None,
        hub_strategy="end",  # ignored when push_to_hub is False
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=encoded[train_split],
        eval_dataset=encoded[test_split],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=_compute_metrics,
    )

    trainer.train()
    trainer.evaluate()
    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)
    if args.push_to_hub:
        trainer.push_to_hub()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
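
# A minimal sketch of using the pushed model for inference, assuming training
# completed and the Hub push succeeded under the default --hub-model-id
# (FlowRank/mailSort); the example subject/body text is hypothetical. Kept as
# comments so importing or running this script is unaffected:
#
#   from transformers import pipeline
#
#   clf = pipeline("text-classification", model="FlowRank/mailSort")
#   print(clf("Subject: Invoice overdue\n\nBody: Please pay by Friday."))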