Ippoboi commited on
Commit
0c3e5e8
·
verified ·
1 Parent(s): 61b65a5

create readme

Browse files
Files changed (1) hide show
  1. README.md +240 -0
README.md ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model:
3
+ - cservan/malbert-base-cased-128k
4
+ license: apache-2.0
5
+ language:
6
+ - en
7
+ - fr
8
+ pipeline_tag: text-classification
9
+ inference: false
10
+ tags:
11
+ - classification
12
+ - emails
13
+ - multilingual
14
+ - albert
15
+ - onnx
16
+ - mobile
17
+ - int8
18
+ widget:
19
+ - text: "Subject: Your order has shipped\n\nBody: Your order #12345 is on its way and will arrive by Monday."
20
+ example_title: Transaction (EN)
21
+ - text: "Subject: Réunion demain\n\nBody: Salut, peut-on reporter notre réunion de 14h à 15h ? Dis-moi."
22
+ example_title: Personal (FR)
23
+ - text: "Subject: Weekly Newsletter\n\nBody: Check out our latest deals! 50% off everything this weekend."
24
+ example_title: Newsletter (EN)
25
+ - text: "Subject: Alerte de sécurité\n\nBody: Une nouvelle connexion à votre compte depuis Paris, France. Vérifiez que c'est bien vous."
26
+ example_title: Alert (FR)
27
+ ---
28
+
29
+ # Email Classifier (mALBERT ONNX)
30
+
31
+ A dual-head **mALBERT** classifier for email category + action prediction, optimized for on-device inference using ONNX Runtime. Bilingual (English + French), 24M parameters, 50.7 MB after INT8 quantization.
32
+
33
+ ## Model Description
34
+
35
+ Classifies emails into 5 categories and predicts whether the recipient should take action:
36
+
37
+ | Category | Description |
38
+ |----------|-------------|
39
+ | **PERSONAL** | Direct 1:1 human communication, calendar invites from real people, direct messages. Excludes platform notifications. |
40
+ | **NEWSLETTER** | Marketing, promotions, subscribed content. Includes weekly digests, year-in-review recaps, marketing-flavored surveys with rewards. |
41
+ | **TRANSACTION** | Money or order events: receipts, charges, refunds, shipping confirmations with order/booking IDs, payslips, money-transfer notifications. |
42
+ | **ALERT** | Account, security, or infrastructure messages: password resets, login alerts, CI failures, booking-bound expiry, satisfaction surveys without rewards, named-product update notifications. |
43
+ | **SOCIAL** | Platform activity *between people*: post mentions, comment notifications, PR review requests from real users. Excludes automated platform mail (those are ALERT). |
44
+
45
+ The action flag is `true` only when the email requires a concrete response tied to something the user owns or initiated — pay to keep an existing booking, verify a code you requested, accept/decline a calendar invite, reply to a 1:1 message, security event needing verification, or a support ticket follow-up.
46
+
47
+ ### Output Format
48
+
49
+ Single forward pass producing two tensors:
50
+ - `category_probs`: Float32[5] — softmax probabilities per category (argmax = predicted category)
51
+ - `action_prob`: Float32[1] — sigmoid probability of action required (threshold 0.5)
52
+
53
+ No text generation, no decoder, no beam search.
54
+
55
+ **Example:**
56
+
57
+ ```
58
+ Input: "Subject: Your order has shipped\n\nBody: Your order #12345 is on its way..."
59
+ Output: category_probs → TRANSACTION (0.94), action_prob → 0.08 (NO_ACTION)
60
+ ```
61
+
62
+ ## Intended Use
63
+
64
+ - **Primary:** On-device email triage in mobile apps (iOS/Android)
65
+ - **Runtime:** ONNX Runtime React Native
66
+ - **Use case:** Prioritizing inbox, filtering noise, surfacing actionable emails
67
+
68
+ ## Model Details
69
+
70
+ | Attribute | Value |
71
+ |-----------|-------|
72
+ | Base Model | `cservan/malbert-base-cased-128k` |
73
+ | Parameters | ~24M |
74
+ | Architecture | ALBERT encoder (parameter-shared, 1 physical block × 12 virtual layers) + dual classification heads |
75
+ | Pooling | `pooler_output` (SOP-pretrained linear + tanh) |
76
+ | ONNX Size | 50.7 MB (INT8 quantized, 1.8× compression from FP32) |
77
+ | Max Sequence | 384 tokens |
78
+ | Tokenizer | SentencePiece Unigram (128K vocab, French-aware) |
79
+ | Hidden Size | 768 |
80
+ | Special Tokens | `[CLS]=2`, `[SEP]=3`, `<pad>=0`, `<unk>=1` |
81
+
82
+ ## Performance
83
+
84
+ Test set metrics (250 emails, balanced across categories, EN+FR):
85
+
86
+ | Metric | Score |
87
+ |--------|-------|
88
+ | **Category Accuracy** | **86.0%** (single seed) / **88.4%** (2-seed soft-vote ensemble) |
89
+ | **Action Accuracy** | **84.8%** |
90
+ | Quantization | INT8 dynamic, 20/20 PyTorch↔ONNX argmax parity |
91
+
92
+ ### Per-language breakdown (single seed)
93
+
94
+ | | English | French |
95
+ |---|---|---|
96
+ | Category accuracy | 85.4% | **87.0%** |
97
+ | Action accuracy | 89.2% | 77.2% |
98
+
99
+ Notable: French slightly outperforms English on category — the multilingual signal is symmetric. Action accuracy retains an EN advantage (~12 pts) reflecting heavier representation of EN action patterns in training data.
100
+
101
+ ### Per-class F1 (single seed)
102
+
103
+ | Class | Precision | Recall | F1 |
104
+ |---|---|---|---|
105
+ | ALERT | 0.885 | 0.900 | 0.893 |
106
+ | NEWSLETTER | 0.771 | 0.900 | 0.831 |
107
+ | PERSONAL | 0.917 | 0.892 | 0.904 |
108
+ | SOCIAL | 0.862 | 0.758 | 0.807 |
109
+ | TRANSACTION | 0.907 | 0.817 | 0.860 |
110
+
111
+ ## Training Data
112
+
113
+ - **Source:** Personal Gmail inboxes (anonymized)
114
+ - **Languages:** English, French
115
+ - **Size:** 2,005 train / 251 val / 250 test (balanced)
116
+ - **Labeling:** Human-annotated with category + action flag, prompt-assisted with v7 labeling rules (precise tie-breakers for booking-bound deadlines, marketing recaps with reward language, CI/security automation, curated personalized outreach, satisfaction surveys with/without incentives)
117
+ - **Input format:** `Subject: ...\n\nBody: ...` (no instruction prefix)
118
+
119
+ ## How to Use
120
+
121
+ ### ONNX Runtime (React Native)
122
+
123
+ ```typescript
124
+ import { InferenceSession, Tensor } from 'onnxruntime-react-native';
125
+
126
+ const session = await InferenceSession.create('model.onnx');
127
+
128
+ const outputs = await session.run({
129
+ input_ids: inputIdsTensor, // int64[1, seq_len]
130
+ attention_mask: attentionMaskTensor, // int64[1, seq_len]
131
+ token_type_ids: tokenTypeIdsTensor, // int64[1, seq_len], all zeros
132
+ });
133
+
134
+ const categoryProbs = outputs.category_probs.data; // Float32[5]
135
+ const actionProb = outputs.action_prob.data[0]; // Float32
136
+ ```
137
+
138
+ ### Python (PyTorch reference)
139
+
140
+ ```python
141
+ from transformers import AutoTokenizer
142
+ import torch
143
+
144
+ tokenizer = AutoTokenizer.from_pretrained("Ippoboi/malbert-email-classifier")
145
+ # Load DualHeadClassifier from checkpoint (see ml/scripts/train_classifier.py)
146
+
147
+ text = "Subject: Réunion demain\n\nBody: Peut-on reporter à 15h ?"
148
+ inputs = tokenizer(text, return_tensors="pt", max_length=384, truncation=True)
149
+
150
+ with torch.no_grad():
151
+ cat_logits, act_logits = model(inputs["input_ids"], inputs["attention_mask"])
152
+ category = ["ALERT", "NEWSLETTER", "PERSONAL", "SOCIAL", "TRANSACTION"][cat_logits.argmax()]
153
+ action = torch.sigmoid(act_logits).item() > 0.5
154
+ ```
155
+
156
+ ### ONNX Runtime (Python)
157
+
158
+ ```python
159
+ import onnxruntime as ort
160
+ from transformers import AutoTokenizer
161
+ import numpy as np
162
+
163
+ session = ort.InferenceSession("model.onnx")
164
+ tokenizer = AutoTokenizer.from_pretrained("Ippoboi/malbert-email-classifier")
165
+
166
+ inputs = tokenizer(
167
+ "Subject: Your order has shipped\n\nBody: ...",
168
+ return_tensors="np",
169
+ max_length=384,
170
+ truncation=True,
171
+ padding="max_length",
172
+ )
173
+ cat_probs, act_prob = session.run(
174
+ ["category_probs", "action_prob"],
175
+ {
176
+ "input_ids": inputs["input_ids"].astype(np.int64),
177
+ "attention_mask": inputs["attention_mask"].astype(np.int64),
178
+ "token_type_ids": np.zeros_like(inputs["input_ids"], dtype=np.int64),
179
+ },
180
+ )
181
+ categories = ["ALERT", "NEWSLETTER", "PERSONAL", "SOCIAL", "TRANSACTION"]
182
+ print(categories[cat_probs[0].argmax()], "action:", act_prob[0] > 0.5)
183
+ ```
184
+
185
+ ## Files
186
+
187
+ | File | Size | Description |
188
+ |------|------|-------------|
189
+ | `model.onnx` | 50.7 MB | INT8 quantized ONNX model |
190
+ | `tokenizer.json` | 8.2 MB | Fast tokenizer (SentencePiece Unigram, 128K vocab) |
191
+ | `spiece.model` | 2.3 MB | Raw SentencePiece vocab (optional, for Python reload) |
192
+ | `tokenizer_config.json` | 1.4 KB | Tokenizer config |
193
+ | `special_tokens_map.json` | 970 B | Special token names → IDs |
194
+
195
+ ## Architecture
196
+
197
+ ```
198
+ Input → ALBERT Encoder (12 virtual layers × 1 shared block, hidden=768)
199
+
200
+ pooler_output (Linear+tanh on [CLS])
201
+
202
+ ┌─────┴─────┐
203
+ ↓ ↓
204
+ Category Head Action Head
205
+ Linear(768→5) Linear(768→1)
206
+ ↓ ↓
207
+ softmax sigmoid
208
+ ↓ ↓
209
+ category_probs action_prob
210
+ ```
211
+
212
+ ALBERT shares one physical transformer block across all 12 virtual layers. This gives ~24M total parameters (vs ~110M for an equivalent BERT-base) at the cost of representational capacity per virtual depth.
213
+
214
+ ## Compared to Previous Model (MiniLM v1)
215
+
216
+ | | MiniLM v1 | mALBERT v3 (this) |
217
+ |---|---|---|
218
+ | Base architecture | XLM-R encoder, independent layers | ALBERT, parameter-shared |
219
+ | Parameters | ~117M | ~24M |
220
+ | ONNX size | 113 MB | **50.7 MB** |
221
+ | Max sequence | 256 | **384** |
222
+ | Vocab size | 250K | 128K |
223
+ | Category accuracy | 92.0% | 86.0% / 88.4% (ensemble) |
224
+ | Action accuracy | 82.8% | **84.8%** |
225
+ | FR cat parity | EN-favored | **EN/FR symmetric** |
226
+
227
+ mALBERT v3 trades raw category accuracy for **less than half the on-device footprint**, **wider context** (384 vs 256 tokens), and **balanced multilingual performance**. Action accuracy is higher; category accuracy is lower in absolute terms but the language gap closes.
228
+
229
+ ## Limitations
230
+
231
+ - Trained on personal email patterns; may not generalize to enterprise/corporate email styles
232
+ - Classification accuracy depends on text quality (plain text preferred over heavy HTML)
233
+ - French action accuracy lags English by ~12 points; the v7 labeling prompt is EN-leaning in its action examples
234
+ - SOCIAL is the weakest category (F1 0.81) — smallest training class (268 examples) and shares features with NEWSLETTER for platform-mass-emails
235
+ - 384-token cap may truncate long emails; ~17% of training emails exceeded this limit
236
+ - ALBERT parameter sharing limits representational depth; for harder boundaries, a non-shared encoder (mDeBERTa-v3-base, MiniLM-L12) would have more capacity at higher inference cost
237
+
238
+ ## License
239
+
240
+ Apache 2.0