enzofrnt, HilarionLefuneste, and Cursor committed
Commit 8153a62 · unverified · Parent: 60c9bc5

feat(training): pipeline minimal train/test + artefacts HF


- Trains a classifier from FlowRank/labeled_emails
- Adds evaluation scripts and artifact preparation under model/
- Documents usage (train/test/publish) in the README

Co-authored-by: Hilarion Lefuneste <hilarionlefuneste@tutamail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

.gitignore ADDED
@@ -0,0 +1,15 @@
+ .venv/
+ __pycache__/
+ *.pyc
+ .DS_Store
+
+ # training outputs (local only)
+ outputs/
+ outputs_smoke/
+
+ # HF caches
+ .cache/
+ **/.cache/
+
+ # build metadata
+ *.egg-info/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
README.md ADDED
@@ -0,0 +1,111 @@
+ ---
+ library_name: transformers
+ pipeline_tag: text-classification
+ tags:
+ - email
+ - text-classification
+ language:
+ - en
+ ---
+
+ ## mailSort
+
+ Minimal Python repo to **train, evaluate, and publish** a multi-class email classification model built from the Hugging Face dataset [`FlowRank/labeled_emails`](https://huggingface.co/datasets/FlowRank/labeled_emails).
+
+ The main entry point is `mailsort.train` (Transformers `Trainer`).
+
+ ### Prerequisites
+
+ - Python managed by `uv` (this repo is meant to be run with `uv run`)
+ - (Optional) a CUDA GPU if you want to speed up training
+ - (Optional) an HF token if you want to publish to the Hub
+
+ ### Installation
+
+ `uv` installs/syncs the dependencies automatically on first run.
+
+ ```bash
+ uv sync
+ ```
+
+ ### Train + evaluate (train + test)
+
+ By default, the script loads the dataset **from the Hub** and uses its `train` and `test` splits.
+ Evaluation runs at every epoch, followed by a final evaluation at the end.
+ Artifacts (model, tokenizer) are saved to `outputs/` (or the directory passed via `--output-dir`).
+
+ ```bash
+ uv run python -m mailsort.train \
+     --dataset-id FlowRank/labeled_emails \
+     --model-name distilbert-base-uncased \
+     --hub-model-id FlowRank/mailSort \
+     --num-train-epochs 2
+ ```
+
+ ### Test / evaluate only
+
+ The script does not (yet) have an eval-only mode.
+ The **minimum** for a quick evaluation-only pass is to set `--num-train-epochs 0` (which skips training) while keeping the `evaluate` phase.
+
+ ```bash
+ uv run python -m mailsort.train --num-train-epochs 0
+ ```
+
+ ### Evaluate on the dataset's `test` split (recommended)
+
+ After training into `outputs/`, you can evaluate cleanly on the **`test` split** of `FlowRank/labeled_emails`:
+
+ ```bash
+ uv run python -m mailsort.eval --model outputs --dataset-id FlowRank/labeled_emails --split test
+ ```
+
+ (Optional) For a quick check:
+
+ ```bash
+ uv run python -m mailsort.eval --model outputs --split test --max-samples 200
+ ```
+
+ ### Publish to the Hub (FlowRank/mailSort)
+
+ The push happens automatically if the `HF_TOKEN` (or `HUGGINGFACE_HUB_TOKEN`) environment variable is set.
+
+ ```bash
+ export HF_TOKEN="..."
+ uv run python -m mailsort.train --hub-model-id FlowRank/mailSort
+ ```
+
+ ### Publish via Git (README at the root + artifacts in `model/`)
+
+ To get a "complete" Hugging Face repo (docs + weights) while keeping a clean layout, put:
+
+ - `README.md` at the root (documentation + model card)
+ - the artifacts (config, weights, tokenizer) in `model/`
+
+ Hugging Face can then load the model via `subfolder="model"`.
+
+ 1) Prepare the `model/` folder from `outputs/`:
+
+ ```bash
+ uv run python -m mailsort.prepare_model --outputs-dir outputs --model-dir model
+ ```
+
+ 2) Commit + push to the Hugging Face repo `FlowRank/mailSort`:
+
+ ```bash
+ git add README.md model
+ git commit -m "Add model artifacts under model/ + docs"
+ git push
+ ```
+
+ ### Inference (using the published model)
+
+ ```python
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
+
+ tok = AutoTokenizer.from_pretrained("FlowRank/mailSort", subfolder="model")
+ model = AutoModelForSequenceClassification.from_pretrained("FlowRank/mailSort", subfolder="model")
+ clf = pipeline("text-classification", model=model, tokenizer=tok, truncation=True)
+ text = "Subject: Insurance claim\n\nBody: Hello, I need to update my policy..."
+ print(clf(text))
+ ```
+
main.py ADDED
@@ -0,0 +1,7 @@
+ from __future__ import annotations
+
+ from mailsort.train import main
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
model/config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForSequenceClassification"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": null,
+   "dim": 768,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "eos_token_id": null,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "family",
+     "1": "finance",
+     "2": "games",
+     "3": "human resources",
+     "4": "medical",
+     "5": "pets",
+     "6": "school",
+     "7": "software engineering",
+     "8": "sport",
+     "9": "work/airbus"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "family": 0,
+     "finance": 1,
+     "games": 2,
+     "human resources": 3,
+     "medical": 4,
+     "pets": 5,
+     "school": 6,
+     "software engineering": 7,
+     "sport": 8,
+     "work/airbus": 9
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "problem_type": "single_label_classification",
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "tie_word_embeddings": true,
+   "transformers_version": "5.8.0",
+   "use_cache": false,
+   "vocab_size": 30522
+ }
model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e88532e8b0d0583f7a96aa686f19f0ab2256ee741f6b55773cbb8f36f520c276
+ size 267857176
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "backend": "tokenizers",
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "is_local": false,
+   "local_files_only": false,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
model/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5f07cdfe1546cecc7b90ed8a2b923d1de3ed2925d0c1de004f011b09e643764
+ size 5265
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [project]
+ name = "mailsort"
+ version = "0.1.0"
+ description = "Add your description here"
+ requires-python = ">=3.11"
+ dependencies = [
+     "accelerate>=1.13.0",
+     "datasets>=4.8.5",
+     "huggingface-hub>=1.13.0",
+     "safetensors>=0.7.0",
+     "torch>=2.11.0",
+     "transformers>=5.8.0",
+ ]
+
+ [project.scripts]
+ mailsort-train = "mailsort.train:main"
+ mailsort-eval = "mailsort.eval:main"
+ mailsort-prepare-model = "mailsort.prepare_model:main"
+
+ [tool.uv]
+ package = true
src/mailsort/__init__.py ADDED
@@ -0,0 +1,2 @@
+ __all__ = []
+
src/mailsort/eval.py ADDED
@@ -0,0 +1,112 @@
+ from __future__ import annotations
+
+ import argparse
+ from collections import Counter, defaultdict
+ from dataclasses import dataclass
+
+ from datasets import load_dataset
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
+
+
+ @dataclass(frozen=True)
+ class Config:
+     dataset_id: str
+     model_id_or_path: str
+     subfolder: str | None
+     split: str
+     max_samples: int | None
+
+
+ def _build_text(subject: str | None, body: str | None) -> str:
+     subject = "" if subject is None else str(subject)
+     body = "" if body is None else str(body)
+     if subject and body:
+         return f"Subject: {subject}\n\nBody: {body}"
+     return subject or body
+
+
+ def _parse_args() -> Config:
+     p = argparse.ArgumentParser(description="Evaluate a model (local or Hub) against a HF dataset split.")
+     p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
+     p.add_argument("--model", default="outputs", help="Local path OR Hugging Face repo id (e.g. FlowRank/mailSort).")
+     p.add_argument("--subfolder", default=None, help="Optional subfolder (e.g. model).")
+     p.add_argument("--split", default="test", help="Which split to evaluate (e.g. test).")
+     p.add_argument("--max-samples", type=int, default=None, help="Limit evaluation to N samples.")
+     a = p.parse_args()
+     return Config(
+         dataset_id=a.dataset_id,
+         model_id_or_path=a.model,
+         subfolder=a.subfolder,
+         split=a.split,
+         max_samples=a.max_samples,
+     )
+
+
+ def main() -> int:
+     cfg = _parse_args()
+
+     ds = load_dataset(cfg.dataset_id)
+     if cfg.split not in ds:
+         raise SystemExit(f"Split '{cfg.split}' not found. Available: {list(ds.keys())}")
+
+     rows = ds[cfg.split]
+     if cfg.max_samples is not None:
+         rows = rows.select(range(min(cfg.max_samples, len(rows))))
+
+     kwargs = {}
+     if cfg.subfolder:
+         kwargs["subfolder"] = cfg.subfolder
+
+     tokenizer = AutoTokenizer.from_pretrained(cfg.model_id_or_path, **kwargs)
+     model = AutoModelForSequenceClassification.from_pretrained(cfg.model_id_or_path, **kwargs)
+     clf = pipeline(
+         "text-classification",
+         model=model,
+         tokenizer=tokenizer,
+         truncation=True,
+     )
+
+     correct = 0
+     total = 0
+     per_label = Counter()
+     per_label_ok = Counter()
+     confusion = defaultdict(Counter)  # true -> pred -> count
+
+     for ex in rows:
+         text = _build_text(ex.get("subject"), ex.get("body"))
+         true_label = str(ex["label"])
+         pred = clf(text, top_k=1)[0]["label"]
+
+         total += 1
+         per_label[true_label] += 1
+         confusion[true_label][pred] += 1
+
+         if pred == true_label:
+             correct += 1
+             per_label_ok[true_label] += 1
+
+     acc = correct / total if total else 0.0
+     print(f"dataset={cfg.dataset_id} split={cfg.split} samples={total}")
+     print(f"accuracy={acc:.4f} ({correct}/{total})")
+     print("\nper-label accuracy:")
+     for label in sorted(per_label.keys()):
+         denom = per_label[label]
+         num = per_label_ok[label]
+         print(f"- {label}: {num}/{denom} = {num/denom:.4f}")
+
+     # print top confusions per label (lightweight)
+     print("\ncommon confusions (top-3 per true label):")
+     for true_label in sorted(confusion.keys()):
+         most = confusion[true_label].most_common(3)
+         # skip rows whose only prediction is the correct label
+         if len(most) == 1 and most[0][0] == true_label:
+             continue
+         top = ", ".join(f"{pred}:{cnt}" for pred, cnt in most)
+         print(f"- {true_label}: {top}")
+
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
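The text-building and per-label bookkeeping in `eval.py` above can be exercised without downloading a model or dataset. The sketch below mirrors `_build_text` and the `Counter`-based accuracy tallies; the label names and toy (true, predicted) pairs are made up for illustration:

```python
from collections import Counter

def build_text(subject, body):
    # Mirror of _build_text: join subject and body, tolerating missing fields.
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body

# Toy (true_label, predicted_label) pairs standing in for pipeline output.
examples = [("finance", "finance"), ("finance", "work/airbus"), ("pets", "pets")]

per_label, per_label_ok = Counter(), Counter()
for true_label, pred in examples:
    per_label[true_label] += 1          # how many examples carry this label
    if pred == true_label:
        per_label_ok[true_label] += 1   # how many of those were classified correctly

report = {l: f"{per_label_ok[l]}/{per_label[l]}" for l in sorted(per_label)}
print(report)  # → {'finance': '1/2', 'pets': '1/1'}
```

Running the real script replaces `examples` with predictions from the `text-classification` pipeline over the chosen dataset split.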
src/mailsort/prepare_model.py ADDED
@@ -0,0 +1,61 @@
+ from __future__ import annotations
+
+ import argparse
+ import shutil
+ from dataclasses import dataclass
+ from pathlib import Path
+
+
+ @dataclass(frozen=True)
+ class Config:
+     outputs_dir: Path
+     model_dir: Path
+
+
+ def _parse_args() -> Config:
+     p = argparse.ArgumentParser(description="Prepare model/ folder from an outputs/ training directory.")
+     p.add_argument("--outputs-dir", default="outputs", help="Training output directory (from mailsort.train).")
+     p.add_argument("--model-dir", default="model", help="Target folder to commit/push to Hugging Face.")
+     a = p.parse_args()
+     return Config(outputs_dir=Path(a.outputs_dir), model_dir=Path(a.model_dir))
+
+
+ def main() -> int:
+     cfg = _parse_args()
+
+     if not cfg.outputs_dir.exists():
+         raise SystemExit(f"outputs-dir not found: {cfg.outputs_dir}")
+
+     cfg.model_dir.mkdir(parents=True, exist_ok=True)
+
+     # clean target (keep it explicit and predictable)
+     for p in cfg.model_dir.iterdir():
+         if p.is_dir():
+             shutil.rmtree(p)
+         else:
+             p.unlink()
+
+     # Copy only final artifacts (root files), ignore trainer checkpoints.
+     for p in cfg.outputs_dir.iterdir():
+         if p.is_dir():
+             # ignore checkpoint-* dirs
+             continue
+         shutil.copy2(p, cfg.model_dir / p.name)
+
+     # sanity check: expected minimum files
+     expected_any = [
+         "config.json",
+         "tokenizer.json",
+         "tokenizer_config.json",
+     ]
+     missing = [n for n in expected_any if not (cfg.model_dir / n).exists()]
+     if missing:
+         raise SystemExit(f"Missing expected files in {cfg.model_dir}: {missing}")
+
+     print(f"Prepared {cfg.model_dir} from {cfg.outputs_dir}")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
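The copy step in `prepare_model.py` above (root files only, `checkpoint-*` directories skipped) can be sketched against a throwaway directory tree; the file names here are illustrative stand-ins for real trainer output:

```python
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    outputs = Path(tmp) / "outputs"
    model_dir = Path(tmp) / "model"

    # Fake a training output dir: one checkpoint dir plus two root artifacts.
    (outputs / "checkpoint-100").mkdir(parents=True)  # trainer checkpoint: skipped
    (outputs / "config.json").write_text("{}")        # root artifact: copied
    (outputs / "tokenizer.json").write_text("{}")     # root artifact: copied

    model_dir.mkdir()
    for p in outputs.iterdir():
        if p.is_dir():
            continue  # checkpoint-* directories are not published
        shutil.copy2(p, model_dir / p.name)

    copied = sorted(x.name for x in model_dir.iterdir())

print(copied)  # → ['config.json', 'tokenizer.json']
```

Only the flat artifacts end up in `model/`, which keeps the Git-pushed repo free of intermediate checkpoints.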
src/mailsort/train.py ADDED
@@ -0,0 +1,171 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+ from dataclasses import dataclass
+
+ import numpy as np
+ from datasets import DatasetDict, load_dataset
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     DataCollatorWithPadding,
+     Trainer,
+     TrainingArguments,
+ )
+
+
+ @dataclass(frozen=True)
+ class Config:
+     dataset_id: str
+     model_name: str
+     hub_model_id: str
+     output_dir: str
+     max_length: int
+     num_train_epochs: float
+     per_device_train_batch_size: int
+     per_device_eval_batch_size: int
+     learning_rate: float
+     weight_decay: float
+     seed: int
+
+
+ def _build_text(subject: str | None, body: str | None) -> str:
+     subject = "" if subject is None else str(subject)
+     body = "" if body is None else str(body)
+     if subject and body:
+         return f"Subject: {subject}\n\nBody: {body}"
+     return subject or body
+
+
+ def _parse_args() -> Config:
+     p = argparse.ArgumentParser(description="Train & push email classifier to Hugging Face Hub.")
+     p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
+     p.add_argument("--model-name", default="distilbert-base-uncased")
+     p.add_argument("--hub-model-id", default="FlowRank/mailSort")
+     p.add_argument("--output-dir", default="outputs")
+     p.add_argument("--max-length", type=int, default=256)
+     p.add_argument("--num-train-epochs", type=float, default=2)
+     p.add_argument("--per-device-train-batch-size", type=int, default=16)
+     p.add_argument("--per-device-eval-batch-size", type=int, default=32)
+     p.add_argument("--learning-rate", type=float, default=2e-5)
+     p.add_argument("--weight-decay", type=float, default=0.01)
+     p.add_argument("--seed", type=int, default=42)
+     a = p.parse_args()
+     return Config(
+         dataset_id=a.dataset_id,
+         model_name=a.model_name,
+         hub_model_id=a.hub_model_id,
+         output_dir=a.output_dir,
+         max_length=a.max_length,
+         num_train_epochs=a.num_train_epochs,
+         per_device_train_batch_size=a.per_device_train_batch_size,
+         per_device_eval_batch_size=a.per_device_eval_batch_size,
+         learning_rate=a.learning_rate,
+         weight_decay=a.weight_decay,
+         seed=a.seed,
+     )
+
+
+ def _load_ds(dataset_id: str, seed: int) -> DatasetDict:
+     ds = load_dataset(dataset_id)
+     if "train" in ds and "test" in ds:
+         return ds  # already split
+     # fallback: split if only a single split exists
+     if "train" in ds and "test" not in ds:
+         return ds["train"].train_test_split(test_size=0.1, seed=seed)
+     # if the structure is unexpected, return as-is and let Trainer fail loudly
+     return ds
+
+
+ def _prepare(ds: DatasetDict, tokenizer: AutoTokenizer, label2id: dict[str, int], max_length: int) -> DatasetDict:
+     def preprocess(ex):
+         text = _build_text(ex.get("subject"), ex.get("body"))
+         out = tokenizer(text, truncation=True, max_length=max_length)
+         out["labels"] = label2id[str(ex["label"])]
+         return out
+
+     cols_to_remove = [c for c in ["subject", "body", "label"] if c in ds["train"].column_names]
+     return ds.map(preprocess, remove_columns=cols_to_remove)
+
+
+ def _compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     preds = np.argmax(logits, axis=-1)
+     acc = (preds == labels).astype(np.float32).mean().item()
+     return {"accuracy": acc}
+
+
+ def main() -> int:
+     cfg = _parse_args()
+
+     ds = _load_ds(cfg.dataset_id, seed=cfg.seed)
+     train_split = "train" if "train" in ds else list(ds.keys())[0]
+     test_split = "test" if "test" in ds else ("validation" if "validation" in ds else None)
+
+     if test_split is None:
+         raise SystemExit(f"Dataset must have a test/validation split. Found: {list(ds.keys())}")
+
+     tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
+
+     labels = sorted({str(x) for x in ds[train_split]["label"]})
+     label2id = {l: i for i, l in enumerate(labels)}
+     id2label = {i: l for l, i in label2id.items()}
+
+     encoded = _prepare(ds, tokenizer, label2id=label2id, max_length=cfg.max_length)
+     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+     model = AutoModelForSequenceClassification.from_pretrained(
+         cfg.model_name,
+         num_labels=len(labels),
+         label2id=label2id,
+         id2label=id2label,
+     )
+
+     push_to_hub = bool(os.getenv("HF_TOKEN")) or bool(os.getenv("HUGGINGFACE_HUB_TOKEN"))
+
+     args = TrainingArguments(
+         output_dir=cfg.output_dir,
+         num_train_epochs=cfg.num_train_epochs,
+         learning_rate=cfg.learning_rate,
+         per_device_train_batch_size=cfg.per_device_train_batch_size,
+         per_device_eval_batch_size=cfg.per_device_eval_batch_size,
+         weight_decay=cfg.weight_decay,
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         logging_strategy="steps",
+         logging_steps=50,
+         load_best_model_at_end=True,
+         metric_for_best_model="accuracy",
+         seed=cfg.seed,
+         report_to="none",
+         push_to_hub=push_to_hub,
+         hub_model_id=cfg.hub_model_id if push_to_hub else None,
+         hub_strategy="end" if push_to_hub else "every_save",
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=args,
+         train_dataset=encoded[train_split],
+         eval_dataset=encoded[test_split],
+         processing_class=tokenizer,
+         data_collator=data_collator,
+         compute_metrics=_compute_metrics,
+     )
+
+     trainer.train()
+     trainer.evaluate()
+
+     trainer.save_model(cfg.output_dir)
+     tokenizer.save_pretrained(cfg.output_dir)
+
+     if args.push_to_hub:
+         trainer.push_to_hub()
+
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
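The label mapping in `train.py` above sorts the label set before assigning ids, so `label2id`/`id2label` stay stable across runs regardless of the order labels appear in the dataset. A minimal sketch with made-up labels (a subset of the categories in `model/config.json`):

```python
# Raw label column as it might come out of the dataset (duplicates, arbitrary order).
raw_labels = ["work/airbus", "finance", "pets", "finance"]

# Dedupe via a set, then sort for a deterministic ordering.
labels = sorted({str(x) for x in raw_labels})
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}

print(label2id)  # → {'finance': 0, 'pets': 1, 'work/airbus': 2}
print(id2label)  # → {0: 'finance', 1: 'pets', 2: 'work/airbus'}
```

This is the mapping baked into the model config, which is why `model/config.json` above lists the ten categories in alphabetical order.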
uv.lock ADDED
The diff for this file is too large to render. See raw diff