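"""Fine-tune a sequence-classification model on a labeled email dataset and,
when a Hub token is available, push the result to the Hugging Face Hub.

Example invocation (assuming this file is saved as train.py; every flag is
optional and falls back to the defaults defined in _parse_args):

    python train.py \
        --dataset-id FlowRank/labeled_emails \
        --model-name distilbert-base-uncased \
        --hub-model-id FlowRank/mailSort
"""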
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass

import numpy as np
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)


@dataclass(frozen=True)
class Config:
    dataset_id: str
    model_name: str
    hub_model_id: str
    output_dir: str
    max_length: int
    num_train_epochs: float
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    learning_rate: float
    weight_decay: float
    seed: int


def _build_text(subject: str, body: str) -> str:
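    """Join subject and body into one input string, tolerating missing fields."""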
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body


def _parse_args() -> Config:
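    """Parse CLI flags into an immutable Config."""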
    p = argparse.ArgumentParser(description="Train & push email classifier to Hugging Face Hub.")
    p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
    p.add_argument("--model-name", default="distilbert-base-uncased")
    p.add_argument("--hub-model-id", default="FlowRank/mailSort")
    p.add_argument("--output-dir", default="outputs")
    p.add_argument("--max-length", type=int, default=256)
    p.add_argument("--num-train-epochs", type=float, default=2)
    p.add_argument("--per-device-train-batch-size", type=int, default=16)
    p.add_argument("--per-device-eval-batch-size", type=int, default=32)
    p.add_argument("--learning-rate", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=42)
    a = p.parse_args()
    return Config(
        dataset_id=a.dataset_id,
        model_name=a.model_name,
        hub_model_id=a.hub_model_id,
        output_dir=a.output_dir,
        max_length=a.max_length,
        num_train_epochs=a.num_train_epochs,
        per_device_train_batch_size=a.per_device_train_batch_size,
        per_device_eval_batch_size=a.per_device_eval_batch_size,
        learning_rate=a.learning_rate,
        weight_decay=a.weight_decay,
        seed=a.seed,
    )


def _load_ds(dataset_id: str, seed: int) -> DatasetDict:
    """Load the dataset, carving out a test split if the Hub repo lacks one."""
    ds = load_dataset(dataset_id)
    if "train" in ds and "test" in ds:
        return ds  # already split
    # Fallback: hold out 10% of a lone train split for evaluation.
    if "train" in ds:
        return ds["train"].train_test_split(test_size=0.1, seed=seed)
    # Unexpected structure: return as-is and let main() reject it with a clear error.
    return ds


def _prepare(ds: DatasetDict, tokenizer: PreTrainedTokenizerBase, label2id: dict[str, int], max_length: int) -> DatasetDict:
    """Tokenize subject/body text and attach integer label ids."""

    def preprocess(ex):
        text = _build_text(ex.get("subject"), ex.get("body"))
        out = tokenizer(text, truncation=True, max_length=max_length)
        out["labels"] = label2id[str(ex["label"])]
        return out

    # Inspect the first split rather than assuming it is named "train";
    # main() may fall back to a differently named training split.
    first_split = next(iter(ds.values()))
    cols_to_remove = [c for c in ("subject", "body", "label") if c in first_split.column_names]
    return ds.map(preprocess, remove_columns=cols_to_remove)


def _compute_metrics(eval_pred):
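    """Accuracy over argmax predictions; used to select the best checkpoint."""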
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = (preds == labels).astype(np.float32).mean().item()
    return {"accuracy": acc}


def main() -> int:
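    """Load data, tokenize, train, evaluate, save, and optionally push to the Hub."""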
    cfg = _parse_args()

    ds = _load_ds(cfg.dataset_id, seed=cfg.seed)
    train_split = "train" if "train" in ds else list(ds.keys())[0]
    test_split = "test" if "test" in ds else ("validation" if "validation" in ds else None)

    if test_split is None:
        raise SystemExit(f"Dataset must have a test/validation split. Found: {list(ds.keys())}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)

    # Derive the label vocabulary from the training split; sorting keeps ids deterministic.
    labels = sorted({str(x) for x in ds[train_split]["label"]})
    label2id = {l: i for i, l in enumerate(labels)}
    id2label = {i: l for l, i in label2id.items()}

    encoded = _prepare(ds, tokenizer, label2id=label2id, max_length=cfg.max_length)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # pad dynamically per batch

    model = AutoModelForSequenceClassification.from_pretrained(
        cfg.model_name,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
    )

    # Push only when a Hub token is present in the environment.
    push_to_hub = bool(os.getenv("HF_TOKEN")) or bool(os.getenv("HUGGINGFACE_HUB_TOKEN"))

    args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        learning_rate=cfg.learning_rate,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        per_device_eval_batch_size=cfg.per_device_eval_batch_size,
        weight_decay=cfg.weight_decay,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        seed=cfg.seed,
        report_to="none",
        push_to_hub=push_to_hub,
        hub_model_id=cfg.hub_model_id if push_to_hub else None,
        hub_strategy="end",  # only consulted when push_to_hub is True
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=encoded[train_split],
        eval_dataset=encoded[test_split],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=_compute_metrics,
    )

    trainer.train()
    metrics = trainer.evaluate()
    print(f"Final eval metrics: {metrics}")

    # Persist the final (best) model and tokenizer locally as well.
    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()

    return 0


if __name__ == "__main__":
    raise SystemExit(main())