---
language:
- en
license: cc-by-nc-nd-4.0
library_name: generic
tags:
- emotion-recognition
- bayesian-deep-learning
- mc-dropout
- uncertainty-quantification
- multi-label-classification
datasets:
- go_emotions
metrics:
- precision
- recall
- f1
model-index:
- name: EmCoder (v1)
  results:
  - task:
      type: text-classification
      name: Multi-label Emotion Classification
    dataset:
      name: GoEmotions
      type: go_emotions
      split: test
    metrics:
    - name: Macro F1
      type: f1
      value: 0.440
    - name: Macro Precision
      type: precision
      value: 0.408
    - name: Macro Recall
      type: recall
      value: 0.495
---

# EmCoder
> **Probabilistic Emotion Recognition & Uncertainty Quantification**<br>**28-emotion multi-label classifier trained with the MC Dropout methodology**

Unlike standard classifiers, EmCoder quantifies what it doesn't know: Monte Carlo Dropout produces an uncertainty estimate alongside every prediction, which makes the model suitable for high-stakes AI pipelines. EmCoder is optimized for **MC Dropout inference**.

## SOTA benchmark
### Evaluation on the GoEmotions test split (macro avg metrics)
EmCoder trades a few F1 points for a much smaller footprint: it is ~35% smaller than RoBERTa-base and ~45% smaller than ModernBERT-base, while additionally providing per-prediction uncertainty estimates.

| Model | Precision | Recall | F1-Score | Params |
| :--- | :--- | :--- | :--- | :--- |
| **EmCoder (v1)** | **0.408** | **0.495** | **0.440** | **82.1M** |
| Google BERT (Original) | 0.400 | 0.630 | 0.460 | 110M |
| RoBERTa-base | 0.575 | 0.396 | 0.450 | 125M |
| ModernBERT-base | 0.652 | 0.443 | 0.500 | 149M |

## How to use
> Since `.safetensors` files store only model weights, not the class logic, use the provided `emcoder.py` to enable **MC Dropout inference**.<br>EmCoder v1.0 requires the `roberta-base` tokenizer for correct token-to-embedding mapping.
### 1. Setup & Tokenization
```python
from transformers import AutoTokenizer
from emcoder import EmCoder  # Ensure emcoder.py is in your directory

# Load the same tokenizer used during training
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

EMCODER_PATH = "path/to/emcoder"

# Initialize with the same config as training
model = EmCoder.from_pretrained(EMCODER_PATH)
```
### 2. Bayesian inference
To obtain probabilistic outputs and uncertainty metrics, use the `mc_forward` method:
```python
import torch

# Perform 50 stochastic passes
N_SAMPLES = 50
model.eval()

inputs = tokenizer("I am so happy you are here!", return_tensors="pt")
# mc_forward automatically keeps Dropout active, even in model.eval() mode
logits_mc = model.mc_forward(inputs['input_ids'], inputs['attention_mask'], n_samples=N_SAMPLES)

# Bayesian post-processing
# logits_mc shape: (n_samples, batch_size, 28)
probs_all = torch.sigmoid(logits_mc)

mean_probs = probs_all.mean(dim=0)   # Mean predicted probability
uncertainty = probs_all.std(dim=0)   # Epistemic uncertainty (standard deviation)
```
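
As a usage sketch of the post-processed outputs (the 0.5 threshold matches the Performance section below; the uncertainty cut-off is an illustrative assumption, not a released setting):

```python
# Binarize the mean probabilities and flag labels whose epistemic
# uncertainty is too high to trust.
THRESHOLD = 0.5   # same binarization threshold as in the Performance section
MAX_STD = 0.15    # illustrative cut-off, tune for your application

predicted = mean_probs > THRESHOLD             # (batch_size, 28) bool mask
trusted = predicted & (uncertainty < MAX_STD)  # keep only confident labels

for idx in trusted[0].nonzero(as_tuple=True)[0]:
    p = mean_probs[0, idx].item()
    s = uncertainty[0, idx].item()
    print(f"label {idx.item()}: p={p:.3f} +/- {s:.3f}")
```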

## Model Architecture
```mermaid
flowchart LR

subgraph InputGroup["Input Operations"]
direction TB
MCD_Loop(["MC-Inference Loop: N_samples"]):::LoopNode
ids["Batch IDs"]
mask["Batch Mask"]
end

subgraph EmCoderCore["EmCoder Core"]
direction LR
tok_emb["Token Embedding"]
ln_in["Input LayerNorm"]
Transformer["Transformer Encoder"]
final_norm["Final LayerNorm"]
Dropout1[("MC-Dropout")]
Dropout2[("MC-Dropout")]
end

subgraph Row1[" "]
direction LR
InputGroup
EmCoderCore
end

subgraph MLP["Classifier MLP"]
L_lin["Linear 1"]
Dropout3[("MC-Dropout")]
GELU["GELU"]
F_lin["Final Linear"]
end

subgraph ClassifierHead[" "]
direction TB
pool["Masked Mean Pooling"]
MLP
end

subgraph Row2[" "]
direction LR
ClassifierHead
Out(["Class LogitsMC
(n_samples, B, 28)"])

Avg["Bayesian Post-processing"]
end

tok_emb ==> ln_in
ln_in -.-> Dropout1
Dropout1 ==> Transformer
Transformer -.-> Dropout2
Dropout2 ==> final_norm
MCD_Loop -.-> ids
ids ==> tok_emb
final_norm ==> pool
mask ==> pool
pool ==> L_lin
L_lin -.-> Dropout3
Dropout3 ==> GELU
GELU ==> F_lin
F_lin ==> Out
Out ==> Avg
mask ==> Transformer

classDef MCD fill:#424242,stroke:#fbc02d,stroke-width:2px,stroke-dasharray: 5 5,color:#fff
classDef OutNode fill:#0d47a1,stroke:#1976d2,stroke-width:3px,color:#fff,font-weight:bold
classDef BayesNode fill:#3e2723,stroke:#8d6e63,stroke-width:2px,stroke-dasharray: 3 3,color:#fff
classDef LoopNode fill:#263238,stroke:#78909c,stroke-width:2px,color:#fff,font-style:italic
classDef LightNode fill:#212121,stroke:#90a4ae,color:#fff

class MCD_Loop LoopNode
class ids,mask,tok_emb,ln_in,Transformer,final_norm,L_lin,GELU,F_lin,pool LightNode
class Dropout1,Dropout2,Dropout3 MCD
class Out OutNode
class Avg BayesNode

style InputGroup fill:#1a1a1a,stroke:#444,color:#fff
style EmCoderCore fill:#2d1a2d,stroke:#6a1b9a,color:#fff
style MLP fill:#212121,stroke:#455a64,color:#fff
style ClassifierHead fill:#012a4a,stroke:#01497c,color:#fff
style Row1 fill:none,stroke:none
style Row2 fill:none,stroke:none

linkStyle 2 stroke:#fbc02d,stroke-width:2px,fill:none
linkStyle 5 stroke:#fbc02d,stroke-width:2px,fill:none
linkStyle 11 stroke:#fbc02d,stroke-width:2px,fill:none
```
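
The diagram's MC-Dropout nodes stay stochastic across the N_samples loop. Below is a minimal sketch of how such a loop can be implemented; this is the generic MC Dropout technique, and the function name and signature are illustrative, not the actual code in `emcoder.py`:

```python
import torch
import torch.nn as nn

def mc_dropout_forward(model: nn.Module, input_ids, attention_mask, n_samples=50):
    """Illustrative MC Dropout loop: keep dropout active in eval mode and
    stack the logits of n_samples stochastic passes."""
    model.eval()
    # Re-enable only the Dropout modules so each pass samples a fresh mask
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    with torch.no_grad():
        samples = [model(input_ids, attention_mask) for _ in range(n_samples)]
    return torch.stack(samples, dim=0)  # (n_samples, batch_size, num_labels)
```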


### Optimization
The model is trained with a weighted Bayesian binary cross-entropy loss, averaged over $T$ stochastic forward passes, where $z^{(t)}$ denotes the logits from pass $t$:

$$
\mathcal{L}_{Bayesian} = \frac{1}{T} \sum_{t=1}^{T} \text{BCEWithLogits}(z^{(t)}, y; w)
$$

where the class weights $w$ are computed on a logarithmic class-balancing scale to handle the extreme label imbalance:

$$
w_{c} = \max\left( 0.1, \min\left( 20, 1 + \ln \left( \frac{N_{neg,c} + \epsilon}{N_{pos,c} + \epsilon} \right) \right) \right)
$$
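
As an illustrative sketch of the weighting formula above (the function name, `eps` default, and NumPy implementation are assumptions, not the repository's code):

```python
import numpy as np

def log_balanced_weights(labels: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """labels: (num_examples, num_classes) binary label matrix.
    Returns one BCE weight per class, clipped to [0.1, 20]."""
    n_pos = labels.sum(axis=0)          # N_pos,c: positives per class
    n_neg = labels.shape[0] - n_pos     # N_neg,c: negatives per class
    w = 1.0 + np.log((n_neg + eps) / (n_pos + eps))
    return np.clip(w, 0.1, 20.0)
```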


## Performance
**Using a threshold of 0.5 for binarizing predictions**

Averaged metrics:

| Average | precision | recall | f1-score | support |
|:---------------|------------:|---------:|-----------:|----------:|
| micro avg | 0.494 | 0.596 | 0.54 | 6329 |
| macro avg | 0.408 | 0.495 | 0.44 | 6329 |
| weighted avg | 0.492 | 0.596 | 0.535 | 6329 |
| samples avg | 0.525 | 0.616 | 0.544 | 6329 |

Per-class metrics:

| Emotion | precision | recall | f1-score | support |
|:---------------|------------:|---------:|-----------:|----------:|
| admiration | 0.541 | 0.673 | 0.599 | 504 |
| amusement | 0.688 | 0.909 | 0.783 | 264 |
| anger | 0.419 | 0.47 | 0.443 | 198 |
| annoyance | 0.31 | 0.25 | 0.277 | 320 |
| approval | 0.304 | 0.271 | 0.287 | 351 |
| caring | 0.229 | 0.281 | 0.252 | 135 |
| confusion | 0.26 | 0.497 | 0.342 | 153 |
| curiosity | 0.432 | 0.764 | 0.552 | 284 |
| desire | 0.453 | 0.518 | 0.483 | 83 |
| disappointment | 0.176 | 0.152 | 0.163 | 151 |
| disapproval | 0.279 | 0.404 | 0.33 | 267 |
| disgust | 0.447 | 0.545 | 0.491 | 123 |
| embarrassment | 0.325 | 0.351 | 0.338 | 37 |
| excitement | 0.288 | 0.427 | 0.344 | 103 |
| fear | 0.47 | 0.692 | 0.56 | 78 |
| gratitude | 0.834 | 0.943 | 0.885 | 352 |
| grief | 0 | 0 | 0 | 6 |
| joy | 0.445 | 0.652 | 0.529 | 161 |
| love | 0.724 | 0.895 | 0.801 | 238 |
| nervousness | 0.24 | 0.261 | 0.25 | 23 |
| optimism | 0.483 | 0.543 | 0.511 | 186 |
| pride | 0.667 | 0.375 | 0.48 | 16 |
| realization | 0.226 | 0.166 | 0.191 | 145 |
| relief | 0.222 | 0.182 | 0.2 | 11 |
| remorse | 0.516 | 0.857 | 0.644 | 56 |
| sadness | 0.405 | 0.545 | 0.464 | 156 |
| surprise | 0.429 | 0.539 | 0.478 | 141 |
| neutral | 0.602 | 0.695 | 0.645 | 1787 |


**Model uncertainty estimation**
![epistemic_unc](outputs/epistemic_unc_scatter.png)

**Confusion matrix**
![multi_label_confusion_matrix](outputs/confusion_matrix.png)



## Workflow
```mermaid
flowchart LR
classDef StageNode fill:#121212,stroke:#546e7a,color:#fff;
classDef HighlightNode fill:#4e342e,stroke:#ff7043,stroke-width:2px,color:#fff,font-weight:bold;

subgraph PT ["Phase 1: Pre-training"]
direction TB
OWT[(OpenWebText)]:::StageNode --> MLM[Masked Language Modeling]:::StageNode
MLM --> Core[Save EmCoderCore]:::StageNode
end

subgraph FT ["Phase 2: Fine-tuning"]
direction TB
Core --> Init[Init ClassificationHead]:::StageNode
GE[(GoEmotions)]:::StageNode --> WBT[Bayesian Fine-tuning]:::HighlightNode
WBT --> LogW[Log-weighted BCE Loss]:::StageNode
LogW --> Freeze[Step 0-500: Encoder Frozen]:::StageNode
end

subgraph EV ["Phase 3: Testing & Inference"]
direction TB
Freeze --> MCD[MC Dropout Inference]:::HighlightNode
MCD --> Unc[Uncertainty Estimation]:::HighlightNode

subgraph Metrics ["Analysis"]
Unc --> EPI[Epistemic: Model Confidence]:::StageNode
Unc --> ALE[Aleatoric: Data Ambiguity]:::StageNode
Unc --> CM[Test set metrics]:::StageNode
end
end

style PT fill:#0d1b2a,stroke:#1b263b,color:#fff
style FT fill:#2e1500,stroke:#5d2a00,color:#fff
style EV fill:#1b2e1b,stroke:#2d4a2d,color:#fff
style Metrics fill:#000,stroke:#333,color:#fff

linkStyle default stroke:#aaa,stroke-width:2px;
```
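
The Analysis phase separates epistemic from aleatoric uncertainty. One common way to compute both from the MC samples is the law-of-total-variance decomposition below; this is a generic sketch that reuses `probs_all` from the inference example, not the repository's exact analysis code:

```python
# Per-label uncertainty decomposition over the MC samples:
# total = aleatoric (data ambiguity) + epistemic (model uncertainty)
aleatoric = (probs_all * (1 - probs_all)).mean(dim=0)   # E_t[p_t (1 - p_t)]
epistemic = probs_all.var(dim=0, unbiased=False)        # Var_t[p_t]
total = aleatoric + epistemic  # equals mean_probs * (1 - mean_probs)
```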


### Note
This model was trained on the GoEmotions dataset (social-media domain) and may not generalize well to other domains.


## Citation
If you use this model, please cite it as follows:

```bibtex
@software{jez2026emcoder,
  author = {Václav Jež},
  title = {EmCoder: Probabilistic Emotion Recognition & Uncertainty Quantification},
  year = {2026},
  publisher = {GitHub},
  howpublished = {\url{https://github.com/yezdata/emcoder}},
  version = {1.0.0}
}
```