Instructions to use FlowRank/mailSort with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use FlowRank/mailSort with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="FlowRank/mailSort")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("FlowRank/mailSort", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
feat(training): pipeline minimal train/test + artefacts HF
- Trains a classifier from FlowRank/labeled_emails
- Adds evaluation and artifact-preparation scripts under model/
- Documents usage (train/test/publish) in the README
Co-authored-by: Hilarion Lefuneste <hilarionlefuneste@tutamail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
- .gitignore +15 -0
- .python-version +1 -0
- README.md +111 -0
- main.py +7 -0
- model/config.json +52 -0
- model/model.safetensors +3 -0
- model/tokenizer.json +0 -0
- model/tokenizer_config.json +15 -0
- model/training_args.bin +3 -0
- pyproject.toml +21 -0
- src/mailsort/__init__.py +2 -0
- src/mailsort/eval.py +112 -0
- src/mailsort/prepare_model.py +61 -0
- src/mailsort/train.py +171 -0
- uv.lock +0 -0
.gitignore
ADDED
```
.venv/
__pycache__/
*.pyc
.DS_Store

# training outputs (local only)
outputs/
outputs_smoke/

# HF caches
.cache/
**/.cache/

# build metadata
*.egg-info/
```
.python-version
ADDED
```
3.11
```
README.md
ADDED

---
library_name: transformers
pipeline_tag: text-classification
tags:
- email
- text-classification
language:
- en
---

## mailSort

Minimal Python repo to **train, evaluate, and publish** a multi-class email classification model from the Hugging Face dataset [`FlowRank/labeled_emails`](https://huggingface.co/datasets/FlowRank/labeled_emails).

The main script is `mailsort.train` (Transformers `Trainer`).

### Prerequisites

- Python managed by `uv` (this repo is meant to be run with `uv run`)
- (Optional) a CUDA GPU to speed up training
- (Optional) an HF token to publish to the Hub

### Installation

`uv` installs/syncs the dependencies automatically on first run.

```bash
uv sync
```

### Train + evaluate (train + test)

By default, the script loads the dataset **from the Hub** and uses its `train` and `test` splits.
Evaluation runs at each epoch, followed by a final evaluation at the end.
Artifacts (model, tokenizer) are saved to `outputs/` (or the directory passed via `--output-dir`).

```bash
uv run python -m mailsort.train \
  --dataset-id FlowRank/labeled_emails \
  --model-name distilbert-base-uncased \
  --hub-model-id FlowRank/mailSort \
  --num-train-epochs 2
```

### Test / evaluate only

The script has no "eval-only" mode (yet).
The **minimal** way to run only a quick pass is to set `--num-train-epochs 0` (which skips training) while keeping the `evaluate` phase.

```bash
uv run python -m mailsort.train --num-train-epochs 0
```

### Evaluate on the dataset's `test` split (recommended)

After training into `outputs/`, you can evaluate cleanly on the **`test` split** of `FlowRank/labeled_emails`:

```bash
uv run python -m mailsort.eval --model outputs --dataset-id FlowRank/labeled_emails --split test
```

(Optional) For a quick test:

```bash
uv run python -m mailsort.eval --model outputs --split test --max-samples 200
```

### Publish to the Hub (FlowRank/mailSort)

The push happens automatically if the `HF_TOKEN` (or `HUGGINGFACE_HUB_TOKEN`) environment variable is set.

```bash
export HF_TOKEN="..."
uv run python -m mailsort.train --hub-model-id FlowRank/mailSort
```

### Publish via Git (README at the root + artifacts in `model/`)

To get a "complete" Hugging Face repo (docs + weights) while keeping a clean structure, we put:

- `README.md` at the root (documentation + model card)
- the artifacts (config, weights, tokenizer) in `model/`

Hugging Face can then load the model via `subfolder="model"`.

1) Prepare the `model/` folder from `outputs/`:

```bash
uv run python -m mailsort.prepare_model --outputs-dir outputs --model-dir model
```

2) Commit + push to the Hugging Face repo `FlowRank/mailSort`:

```bash
git add README.md model
git commit -m "Add model artifacts under model/ + docs"
git push
```

### Inference (use the published model)

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

tok = AutoTokenizer.from_pretrained("FlowRank/mailSort", subfolder="model")
model = AutoModelForSequenceClassification.from_pretrained("FlowRank/mailSort", subfolder="model")
clf = pipeline("text-classification", model=model, tokenizer=tok, truncation=True)
text = "Subject: Insurance claim\n\nBody: Hello, I need to update my policy..."
print(clf(text))
```
main.py
ADDED
```python
from __future__ import annotations

from mailsort.train import main


if __name__ == "__main__":
    raise SystemExit(main())
```
model/config.json
ADDED
```json
{
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": null,
  "dim": 768,
  "dropout": 0.1,
  "dtype": "float32",
  "eos_token_id": null,
  "hidden_dim": 3072,
  "id2label": {
    "0": "family",
    "1": "finance",
    "2": "games",
    "3": "human resources",
    "4": "medical",
    "5": "pets",
    "6": "school",
    "7": "software engineering",
    "8": "sport",
    "9": "work/airbus"
  },
  "initializer_range": 0.02,
  "label2id": {
    "family": 0,
    "finance": 1,
    "games": 2,
    "human resources": 3,
    "medical": 4,
    "pets": 5,
    "school": 6,
    "software engineering": 7,
    "sport": 8,
    "work/airbus": 9
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "tie_word_embeddings": true,
  "transformers_version": "5.8.0",
  "use_cache": false,
  "vocab_size": 30522
}
```
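The `id2label`/`label2id` tables in this config are what turn raw classifier scores into category names at inference time. A minimal sketch of that mapping in plain Python (no transformers dependency; the logit values below are made up for illustration):

```python
# id2label as stored in model/config.json (note: JSON object keys are strings)
id2label = {
    "0": "family", "1": "finance", "2": "games", "3": "human resources",
    "4": "medical", "5": "pets", "6": "school", "7": "software engineering",
    "8": "sport", "9": "work/airbus",
}

def predict_label(logits: list[float]) -> str:
    """Pick the highest-scoring class index and map it to its label name."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return id2label[str(best)]

# Toy logits: index 1 ("finance") has the largest score.
toy_logits = [0.1, 3.2, -1.0, 0.0, 0.5, -2.1, 0.3, 1.1, -0.4, 0.2]
print(predict_label(toy_logits))  # finance
```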
model/model.safetensors
ADDED
```
version https://git-lfs.github.com/spec/v1
oid sha256:e88532e8b0d0583f7a96aa686f19f0ab2256ee741f6b55773cbb8f36f520c276
size 267857176
```
model/tokenizer.json
ADDED

The diff for this file is too large to render. See raw diff.
model/tokenizer_config.json
ADDED
```json
{
  "backend": "tokenizers",
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "is_local": false,
  "local_files_only": false,
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
```
model/training_args.bin
ADDED
```
version https://git-lfs.github.com/spec/v1
oid sha256:e5f07cdfe1546cecc7b90ed8a2b923d1de3ed2925d0c1de004f011b09e643764
size 5265
```
pyproject.toml
ADDED
```toml
[project]
name = "mailsort"
version = "0.1.0"
description = "Add your description here"
requires-python = ">=3.11"
dependencies = [
    "accelerate>=1.13.0",
    "datasets>=4.8.5",
    "huggingface-hub>=1.13.0",
    "safetensors>=0.7.0",
    "torch>=2.11.0",
    "transformers>=5.8.0",
]

[project.scripts]
mailsort-train = "mailsort.train:main"
mailsort-eval = "mailsort.eval:main"
mailsort-prepare-model = "mailsort.prepare_model:main"

[tool.uv]
package = true
```
src/mailsort/__init__.py
ADDED
```python
__all__ = []
```
src/mailsort/eval.py
ADDED
```python
from __future__ import annotations

import argparse
from collections import Counter, defaultdict
from dataclasses import dataclass

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


@dataclass(frozen=True)
class Config:
    dataset_id: str
    model_id_or_path: str
    subfolder: str | None
    split: str
    max_samples: int | None


def _build_text(subject: str, body: str) -> str:
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body


def _parse_args() -> Config:
    p = argparse.ArgumentParser(description="Evaluate a model (local or Hub) against a HF dataset split.")
    p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
    p.add_argument("--model", default="outputs", help="Local path OR Hugging Face repo id (e.g. FlowRank/mailSort).")
    p.add_argument("--subfolder", default=None, help="Optional subfolder (e.g. model).")
    p.add_argument("--split", default="test", help="Which split to evaluate (e.g. test).")
    p.add_argument("--max-samples", type=int, default=None, help="Limit evaluation to N samples.")
    a = p.parse_args()
    return Config(
        dataset_id=a.dataset_id,
        model_id_or_path=a.model,
        subfolder=a.subfolder,
        split=a.split,
        max_samples=a.max_samples,
    )


def main() -> int:
    cfg = _parse_args()

    ds = load_dataset(cfg.dataset_id)
    if cfg.split not in ds:
        raise SystemExit(f"Split '{cfg.split}' not found. Available: {list(ds.keys())}")

    rows = ds[cfg.split]
    if cfg.max_samples is not None:
        rows = rows.select(range(min(cfg.max_samples, len(rows))))

    kwargs = {}
    if cfg.subfolder:
        kwargs["subfolder"] = cfg.subfolder

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_id_or_path, **kwargs)
    model = AutoModelForSequenceClassification.from_pretrained(cfg.model_id_or_path, **kwargs)
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        truncation=True,
    )

    correct = 0
    total = 0
    per_label = Counter()
    per_label_ok = Counter()
    confusion = defaultdict(Counter)  # true -> pred -> count

    for ex in rows:
        text = _build_text(ex.get("subject"), ex.get("body"))
        true_label = str(ex["label"])
        pred = clf(text, top_k=1)[0]["label"]

        total += 1
        per_label[true_label] += 1
        confusion[true_label][pred] += 1

        if pred == true_label:
            correct += 1
            per_label_ok[true_label] += 1

    acc = correct / total if total else 0.0
    print(f"dataset={cfg.dataset_id} split={cfg.split} samples={total}")
    print(f"accuracy={acc:.4f} ({correct}/{total})")
    print("\nper-label accuracy:")
    for label in sorted(per_label.keys()):
        denom = per_label[label]
        num = per_label_ok[label]
        print(f"- {label}: {num}/{denom} = {num/denom:.4f}")

    # print top confusions per label (lightweight)
    print("\ncommon confusions (top-3 per true label):")
    for true_label in sorted(confusion.keys()):
        most = confusion[true_label].most_common(3)
        # skip perfect-only rows
        if len(most) == 1 and most[0][0] == true_label:
            continue
        top = ", ".join([f"{pred}:{cnt}" for pred, cnt in most])
        print(f"- {true_label}: {top}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
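The aggregation at the heart of `eval.py` (overall accuracy, per-label counts, and the true-to-predicted confusion counters) can be exercised standalone with toy predictions; the label pairs below are made up for illustration:

```python
from collections import Counter, defaultdict

# Toy (true_label, predicted_label) pairs; values are illustrative only.
pairs = [
    ("finance", "finance"),
    ("finance", "work/airbus"),
    ("pets", "pets"),
    ("pets", "pets"),
]

correct = 0
per_label = Counter()
per_label_ok = Counter()
confusion = defaultdict(Counter)  # true -> pred -> count

for true_label, pred in pairs:
    per_label[true_label] += 1
    confusion[true_label][pred] += 1
    if pred == true_label:
        correct += 1
        per_label_ok[true_label] += 1

accuracy = correct / len(pairs)
print(accuracy)                                 # 0.75
print(per_label_ok["finance"], per_label["finance"])  # 1 2
print(confusion["finance"]["work/airbus"])      # 1
```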
src/mailsort/prepare_model.py
ADDED
```python
from __future__ import annotations

import argparse
import shutil
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class Config:
    outputs_dir: Path
    model_dir: Path


def _parse_args() -> Config:
    p = argparse.ArgumentParser(description="Prepare model/ folder from an outputs/ training directory.")
    p.add_argument("--outputs-dir", default="outputs", help="Training output directory (from mailsort.train).")
    p.add_argument("--model-dir", default="model", help="Target folder to commit/push to Hugging Face.")
    a = p.parse_args()
    return Config(outputs_dir=Path(a.outputs_dir), model_dir=Path(a.model_dir))


def main() -> int:
    cfg = _parse_args()

    if not cfg.outputs_dir.exists():
        raise SystemExit(f"outputs-dir not found: {cfg.outputs_dir}")

    cfg.model_dir.mkdir(parents=True, exist_ok=True)

    # clean target (keep it explicit and predictable)
    for p in cfg.model_dir.iterdir():
        if p.is_dir():
            shutil.rmtree(p)
        else:
            p.unlink()

    # Copy only final artifacts (root files), ignore trainer checkpoints.
    for p in cfg.outputs_dir.iterdir():
        if p.is_dir():
            # ignore checkpoint-* dirs
            continue
        shutil.copy2(p, cfg.model_dir / p.name)

    # sanity: expected minimum files
    expected_any = [
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]
    missing = [n for n in expected_any if not (cfg.model_dir / n).exists()]
    if missing:
        raise SystemExit(f"Missing expected files in {cfg.model_dir}: {missing}")

    print(f"Prepared {cfg.model_dir} from {cfg.outputs_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
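The core rule in `prepare_model.py` — copy files at the root of `outputs/`, skip trainer `checkpoint-*` directories — can be demonstrated in isolation with a throwaway directory tree (all paths below are temporary and hypothetical):

```python
import shutil
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    outputs = Path(tmp) / "outputs"
    model = Path(tmp) / "model"
    (outputs / "checkpoint-500").mkdir(parents=True)  # trainer checkpoint dir: must be skipped
    (outputs / "config.json").write_text("{}")
    (outputs / "tokenizer.json").write_text("{}")
    model.mkdir()

    # Same rule as prepare_model.py: copy root files, ignore directories.
    for p in outputs.iterdir():
        if p.is_dir():
            continue
        shutil.copy2(p, model / p.name)

    copied = sorted(x.name for x in model.iterdir())
    print(copied)  # ['config.json', 'tokenizer.json']
```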
src/mailsort/train.py
ADDED
```python
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass

import numpy as np
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)


@dataclass(frozen=True)
class Config:
    dataset_id: str
    model_name: str
    hub_model_id: str
    output_dir: str
    max_length: int
    num_train_epochs: float
    per_device_train_batch_size: int
    per_device_eval_batch_size: int
    learning_rate: float
    weight_decay: float
    seed: int


def _build_text(subject: str, body: str) -> str:
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body


def _parse_args() -> Config:
    p = argparse.ArgumentParser(description="Train & push email classifier to Hugging Face Hub.")
    p.add_argument("--dataset-id", default="FlowRank/labeled_emails")
    p.add_argument("--model-name", default="distilbert-base-uncased")
    p.add_argument("--hub-model-id", default="FlowRank/mailSort")
    p.add_argument("--output-dir", default="outputs")
    p.add_argument("--max-length", type=int, default=256)
    p.add_argument("--num-train-epochs", type=float, default=2)
    p.add_argument("--per-device-train-batch-size", type=int, default=16)
    p.add_argument("--per-device-eval-batch-size", type=int, default=32)
    p.add_argument("--learning-rate", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=42)
    a = p.parse_args()
    return Config(
        dataset_id=a.dataset_id,
        model_name=a.model_name,
        hub_model_id=a.hub_model_id,
        output_dir=a.output_dir,
        max_length=a.max_length,
        num_train_epochs=a.num_train_epochs,
        per_device_train_batch_size=a.per_device_train_batch_size,
        per_device_eval_batch_size=a.per_device_eval_batch_size,
        learning_rate=a.learning_rate,
        weight_decay=a.weight_decay,
        seed=a.seed,
    )


def _load_ds(dataset_id: str, seed: int) -> DatasetDict:
    ds = load_dataset(dataset_id)
    if "train" in ds and "test" in ds:
        return ds  # already split
    # fallback: split if only a single split exists
    if "train" in ds and "test" not in ds:
        return ds["train"].train_test_split(test_size=0.1, seed=seed)
    # if weird structure, just return as-is and let Trainer fail loudly
    return ds


def _prepare(ds: DatasetDict, tokenizer: AutoTokenizer, label2id: dict[str, int], max_length: int) -> DatasetDict:
    def preprocess(ex):
        text = _build_text(ex.get("subject"), ex.get("body"))
        out = tokenizer(text, truncation=True, max_length=max_length)
        out["labels"] = label2id[str(ex["label"])]
        return out

    cols_to_remove = [c for c in ["subject", "body", "label"] if c in ds["train"].column_names]
    return ds.map(preprocess, remove_columns=cols_to_remove)


def _compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = (preds == labels).astype(np.float32).mean().item()
    return {"accuracy": acc}


def main() -> int:
    cfg = _parse_args()

    ds = _load_ds(cfg.dataset_id, seed=cfg.seed)
    train_split = "train" if "train" in ds else list(ds.keys())[0]
    test_split = "test" if "test" in ds else ("validation" if "validation" in ds else None)

    if test_split is None:
        raise SystemExit(f"Dataset must have a test/validation split. Found: {list(ds.keys())}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)

    labels = sorted({str(x) for x in ds[train_split]["label"]})
    label2id = {l: i for i, l in enumerate(labels)}
    id2label = {i: l for l, i in label2id.items()}

    encoded = _prepare(ds, tokenizer, label2id=label2id, max_length=cfg.max_length)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(
        cfg.model_name,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
    )

    push_to_hub = bool(os.getenv("HF_TOKEN")) or bool(os.getenv("HUGGINGFACE_HUB_TOKEN"))

    args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_train_epochs,
        learning_rate=cfg.learning_rate,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        per_device_eval_batch_size=cfg.per_device_eval_batch_size,
        weight_decay=cfg.weight_decay,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        seed=cfg.seed,
        report_to="none",
        push_to_hub=push_to_hub,
        hub_model_id=cfg.hub_model_id if push_to_hub else None,
        hub_strategy="end" if push_to_hub else "every_save",
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=encoded[train_split],
        eval_dataset=encoded[test_split],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=_compute_metrics,
    )

    trainer.train()
    trainer.evaluate()

    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
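Two small pieces of `train.py` are easy to sanity-check in isolation: the `Subject:`/`Body:` text template and the deterministic label-to-id mapping (ids are assigned by sorted label name, so the mapping is stable regardless of row order in the dataset). A standalone sketch with toy labels:

```python
def build_text(subject, body):
    # Mirrors _build_text in train.py: join subject and body, tolerate missing fields.
    subject = "" if subject is None else str(subject)
    body = "" if body is None else str(body)
    if subject and body:
        return f"Subject: {subject}\n\nBody: {body}"
    return subject or body

# Same construction as train.py: sort the unique label strings, then enumerate.
labels = sorted({"pets", "finance", "family"})
label2id = {l: i for i, l in enumerate(labels)}
print(label2id)             # {'family': 0, 'finance': 1, 'pets': 2}
print(build_text("Hi", None))  # Hi
```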
uv.lock
ADDED

The diff for this file is too large to render. See raw diff.