permutans committed on
Commit 47ff542 · verified · 1 Parent(s): 84f7e31

Upload folder using huggingface_hub
README.md CHANGED
@@ -2,7 +2,7 @@
  license: mit
  tags:
  - token-classification
- - bert
  - orality
  - linguistics
  - multi-label
@@ -11,7 +11,7 @@ language:
  metrics:
  - f1
  base_model:
- - google-bert/bert-base-uncased
  pipeline_tag: token-classification
  library_name: transformers
  datasets:
@@ -20,7 +20,7 @@ datasets:

  # Havelock Orality Token Classifier

- BERT-based token classifier for detecting **oral and literate markers** in text, based on Walter Ong's "Orality and Literacy" (1982).

  This model performs multi-label span-level detection of 53 rhetorical marker types, where each token independently carries B/I/O labels per type — allowing overlapping spans (e.g. a token that is simultaneously part of a concessive and a nested clause).

@@ -28,13 +28,13 @@ This model performs multi-label span-level detection of 53 rhetorical marker typ

  | Property | Value |
  |----------|-------|
- | Base model | `bert-base-uncased` |
  | Task | Multi-label token classification (independent B/I/O per type) |
  | Marker types | 53 (22 oral, 31 literate) |
- | Test macro F1 | **0.386** (per-type detection, binary positive = B or I) |
- | Training | 20 epochs, batch 24, lr 3e-5, fp16 |
  | Regularization | Mixout (p=0.1) — stochastic L2 anchor to pretrained weights |
- | Loss | Per-type weighted cross-entropy with inverse-frequency type weights |
  | Min examples | 150 (types below this threshold excluded) |

  ## Usage
@@ -118,61 +118,61 @@ Per-type detection F1 on test set (binary: B or I = positive, O = negative):
  ```
  Type Prec Rec F1 Sup
  ========================================================================
- literate_abstract_noun 0.209 0.329 0.255 420
- literate_additive_formal 0.243 0.479 0.322 71
- literate_agent_demoted 0.468 0.664 0.549 414
- literate_agentless_passive 0.555 0.648 0.598 1168
- literate_aside 0.481 0.469 0.475 469
- literate_categorical_statement 0.084 0.263 0.128 118
- literate_causal_explicit 0.314 0.386 0.347 272
- literate_citation 0.468 0.431 0.449 255
- literate_conceptual_metaphor 0.370 0.397 0.383 517
- literate_concessive 0.456 0.503 0.478 533
- literate_concessive_connector 0.250 0.603 0.353 63
- literate_concrete_setting 0.186 0.322 0.236 298
- literate_conditional 0.519 0.548 0.533 1514
- literate_contrastive 0.391 0.462 0.424 424
- literate_cross_reference 0.825 0.316 0.457 253
- literate_definitional_move 0.443 0.432 0.438 236
- literate_enumeration 0.147 0.306 0.198 297
- literate_epistemic_hedge 0.236 0.431 0.305 255
- literate_evidential 0.269 0.472 0.342 106
- literate_institutional_subject 0.157 0.450 0.233 111
- literate_list_structure 0.528 0.614 0.567 295
- literate_metadiscourse 0.355 0.407 0.379 447
- literate_nested_clauses 0.143 0.093 0.113 2044
- literate_nominalization 0.433 0.538 0.480 1013
- literate_objectifying_stance 0.451 0.575 0.506 113
- literate_probability 0.439 0.720 0.545 50
- literate_qualified_assertion 0.186 0.077 0.109 142
- literate_relative_chain 0.344 0.606 0.439 1456
- literate_technical_abbreviation 0.500 0.705 0.585 139
- literate_technical_term 0.278 0.423 0.336 825
- literate_temporal_embedding 0.174 0.253 0.206 400
- oral_anaphora 0.500 0.303 0.377 297
- oral_antithesis 0.298 0.339 0.317 561
- oral_discourse_formula 0.373 0.461 0.413 492
- oral_embodied_action 0.295 0.368 0.327 454
- oral_everyday_example 0.279 0.307 0.293 420
- oral_imperative 0.359 0.600 0.449 110
- oral_inclusive_we 0.579 0.668 0.620 681
- oral_intensifier_doubling 0.429 0.220 0.290 82
- oral_lexical_repetition 0.328 0.382 0.353 275
- oral_named_individual 0.359 0.712 0.478 573
- oral_parallelism 0.111 0.114 0.112 202
- oral_phatic_check 0.288 0.436 0.347 39
- oral_phatic_filler 0.389 0.527 0.448 146
- oral_rhetorical_question 0.581 0.892 0.703 1006
- oral_second_person 0.555 0.528 0.541 718
- oral_self_correction 0.293 0.357 0.322 115
- oral_sensory_detail 0.194 0.402 0.262 246
- oral_simple_conjunction 0.174 0.229 0.198 131
- oral_specific_place 0.453 0.751 0.565 406
- oral_temporal_anchor 0.223 0.704 0.339 257
- oral_tricolon 0.470 0.293 0.361 907
- oral_vocative 0.386 0.942 0.547 52
  ========================================================================
- Macro avg (types w/ support) 0.386
  ```

  </details>
@@ -180,17 +180,17 @@ Macro avg (types w/ support) 0.386
  **Missing labels (test set):** 0/53 — all types detected at least once.

  Notable patterns:
- - **Strong performers** (F1 > 0.5): rhetorical_question (0.703), inclusive_we (0.620), agentless_passive (0.598), technical_abbreviation (0.585), list_structure (0.567), specific_place (0.565), agent_demoted (0.549), vocative (0.547), probability (0.545), second_person (0.541), conditional (0.533), objectifying_stance (0.506)
- - **Weak performers** (F1 < 0.2): qualified_assertion (0.109), parallelism (0.112), nested_clauses (0.113), categorical_statement (0.128), enumeration (0.198), simple_conjunction (0.198)
- - **Precision-recall tradeoff**: Most types show higher recall than precision, indicating the model over-predicts markers. Notable exceptions include `cross_reference` (0.825 precision / 0.316 recall), `anaphora` (0.500 / 0.303), and `tricolon` (0.470 / 0.293), which remain high-precision but low-recall.

  ## Architecture

  Custom `MultiLabelTokenClassifier` with independent B/I/O heads per marker type:
  ```
- BertModel (bert-base-uncased)
  └── Dropout (p=0.1)
- └── Linear (768 → num_types × 3)
  └── Reshape to (batch, seq, num_types, 3)
  ```

@@ -199,13 +199,14 @@ Each marker type gets an independent 3-way O/B/I classification, so a token can
  ### Regularization

  - **Mixout** (p=0.1): During training, each backbone weight element has a 10% chance of being replaced by its pretrained value per forward pass, acting as a stochastic L2 anchor that prevents representation drift (Lee et al., 2020)
  - **Inverse-frequency type weights**: Rare marker types receive higher loss weighting
  - **Inverse-frequency OBI weights**: B and I classes upweighted relative to dominant O class
  - **Weighted random sampling**: Examples containing rarer markers sampled more frequently
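
The Mixout anchor described above can be sketched as a self-contained module. This is a minimal illustration of the idea from Lee et al. (2020), not the repository's training code; `MixoutLinear` is a hypothetical name:

```python
import torch
import torch.nn as nn


class MixoutLinear(nn.Module):
    """Linear layer whose weights are stochastically mixed with a frozen
    pretrained copy during training (Mixout). Illustrative sketch only."""

    def __init__(self, pretrained: nn.Linear, p: float = 0.1):
        super().__init__()
        self.p = p
        self.weight = nn.Parameter(pretrained.weight.detach().clone())
        self.bias = nn.Parameter(pretrained.bias.detach().clone())
        # Frozen anchor copy of the pretrained weights
        self.register_buffer("anchor", pretrained.weight.detach().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        if self.training and self.p > 0:
            # Each weight element reverts to its pretrained value with
            # probability p, then is rescaled (the dropout-style correction
            # from the Mixout paper) so the expectation stays at w.
            mask = torch.rand_like(w) < self.p
            w = torch.where(mask, self.anchor, w)
            w = (w - self.p * self.anchor) / (1.0 - self.p)
        return nn.functional.linear(x, w, self.bias)
```

In eval mode the layer behaves exactly like the underlying `nn.Linear`; the stochastic anchoring only applies during training.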

  ### Initialization

- Fine-tuned from `bert-base-uncased`. Backbone linear layers wrapped with Mixout during training (frozen pretrained copy used as anchor). The classification head is randomly initialized:
  ```
  backbone.* layers → loaded from pretrained, anchored via Mixout
  classifier.weight → randomly initialized
@@ -214,9 +215,8 @@ classifier.bias → randomly initialized

  ## Limitations

- - **Recall-dominated errors**: Most types over-predict (recall > precision), producing false positives; downstream applications may need confidence thresholding
- - **Near-zero recall types**: `qualified_assertion` (0.077 recall), `nested_clauses` (0.093), and `parallelism` (0.114) are rarely detected despite being present in training data
- - **Low-precision types**: `categorical_statement` (0.084), `parallelism` (0.111), and `nested_clauses` (0.143) have precision below 0.15, meaning most predictions for those types are false positives
  - **Context window**: 128 tokens max; longer spans may be truncated
  - **Domain**: Trained primarily on historical/literary texts; may underperform on modern social media
  - **Subjectivity**: Some marker boundaries are inherently ambiguous
@@ -235,6 +235,7 @@ classifier.bias → randomly initialized

  - Ong, Walter J. *Orality and Literacy: The Technologizing of the Word*. Routledge, 1982.
  - Lee, C. et al. "Mixout: Effective Regularization to Finetune Large-scale Pretrained Language Models." ICLR 2020.

  ---

  license: mit
  tags:
  - token-classification
+ - modernbert
  - orality
  - linguistics
  - multi-label

  metrics:
  - f1
  base_model:
+ - answerdotai/ModernBERT-base
  pipeline_tag: token-classification
  library_name: transformers
  datasets:

  # Havelock Orality Token Classifier

+ ModernBERT-based token classifier for detecting **oral and literate markers** in text, based on Walter Ong's "Orality and Literacy" (1982).

  This model performs multi-label span-level detection of 53 rhetorical marker types, where each token independently carries B/I/O labels per type — allowing overlapping spans (e.g. a token that is simultaneously part of a concessive and a nested clause).
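
The overlapping-span labeling can be made concrete with a toy tensor. This is an illustration only; the O/B/I integer encoding is an assumption:

```python
import torch

# Assumed label encoding per type: O = 0, B = 1, I = 2.
O, B, I = 0, 1, 2
types = ["concessive", "nested_clauses"]

# labels[token, type]: each column is an independent B/I/O sequence,
# so one token can sit inside two different spans at once.
labels = torch.tensor([
    [B, O],   # token 0 opens a concessive span
    [I, B],   # token 1 continues it AND opens a nested clause
    [I, I],   # token 2 is inside both spans simultaneously
    [O, I],   # token 3 only continues the nested clause
])
```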

  | Property | Value |
  |----------|-------|
+ | Base model | `answerdotai/ModernBERT-base` |
  | Task | Multi-label token classification (independent B/I/O per type) |
  | Marker types | 53 (22 oral, 31 literate) |
+ | Test macro F1 | **0.378** (per-type detection, binary positive = B or I) |
+ | Training | 20 epochs, fp16 |
  | Regularization | Mixout (p=0.1) — stochastic L2 anchor to pretrained weights |
+ | Loss | Per-type focal loss (γ=2.0) with inverse-frequency OBI and type weights |
  | Min examples | 150 (types below this threshold excluded) |

  ## Usage
 
  ```
  Type Prec Rec F1 Sup
  ========================================================================
+ literate_abstract_noun 0.190 0.325 0.240 381
+ literate_additive_formal 0.246 0.556 0.341 27
+ literate_agent_demoted 0.404 0.368 0.386 304
+ literate_agentless_passive 0.575 0.607 0.591 1133
+ literate_aside 0.379 0.429 0.403 436
+ literate_categorical_statement 0.267 0.146 0.189 514
+ literate_causal_explicit 0.227 0.279 0.251 190
+ literate_citation 0.639 0.556 0.595 372
+ literate_conceptual_metaphor 0.310 0.364 0.335 415
+ literate_concessive 0.499 0.470 0.484 502
+ literate_concessive_connector 0.455 0.408 0.430 49
+ literate_concrete_setting 0.241 0.125 0.165 407
+ literate_conditional 0.369 0.630 0.466 760
+ literate_contrastive 0.310 0.428 0.360 341
+ literate_cross_reference 0.386 0.524 0.444 42
+ literate_definitional_move 0.395 0.185 0.252 81
+ literate_enumeration 0.495 0.483 0.489 775
+ literate_epistemic_hedge 0.421 0.481 0.449 445
+ literate_evidential 0.625 0.360 0.457 472
+ literate_institutional_subject 0.332 0.326 0.329 282
+ literate_list_structure 0.338 0.523 0.411 86
+ literate_metadiscourse 0.140 0.393 0.206 135
+ literate_nested_clauses 0.091 0.246 0.133 1169
+ literate_nominalization 0.499 0.612 0.549 991
+ literate_objectifying_stance 0.635 0.365 0.464 167
+ literate_probability 0.432 0.593 0.500 27
+ literate_qualified_assertion 0.143 0.100 0.118 40
+ literate_relative_chain 0.382 0.507 0.436 1424
+ literate_technical_abbreviation 0.667 0.711 0.688 225
+ literate_technical_term 0.280 0.375 0.321 715
+ literate_temporal_embedding 0.228 0.259 0.242 526
+ oral_anaphora 0.800 0.028 0.054 287
+ oral_antithesis 0.249 0.238 0.243 412
+ oral_discourse_formula 0.340 0.408 0.371 557
+ oral_embodied_action 0.280 0.391 0.326 425
+ oral_everyday_example 0.333 0.156 0.212 404
+ oral_imperative 0.591 0.662 0.625 293
+ oral_inclusive_we 0.516 0.632 0.568 622
+ oral_intensifier_doubling 0.680 0.200 0.309 85
+ oral_lexical_repetition 0.404 0.254 0.312 173
+ oral_named_individual 0.441 0.749 0.556 770
+ oral_parallelism 0.741 0.110 0.191 182
+ oral_phatic_check 0.611 0.733 0.667 30
+ oral_phatic_filler 0.174 0.409 0.244 93
+ oral_rhetorical_question 0.509 0.692 0.586 905
+ oral_second_person 0.576 0.552 0.564 811
+ oral_self_correction 0.158 0.235 0.189 51
+ oral_sensory_detail 0.285 0.169 0.212 461
+ oral_simple_conjunction 0.179 0.102 0.130 98
+ oral_specific_place 0.556 0.705 0.622 424
+ oral_temporal_anchor 0.410 0.559 0.473 546
+ oral_tricolon 0.299 0.119 0.171 553
+ oral_vocative 0.652 0.747 0.696 158
  ========================================================================
+ Macro avg (types w/ support) 0.378
  ```

  </details>
 
  **Missing labels (test set):** 0/53 — all types detected at least once.

  Notable patterns:
+ - **Strong performers** (F1 > 0.5): vocative (0.696), technical_abbreviation (0.688), phatic_check (0.667), imperative (0.625), specific_place (0.622), citation (0.595), agentless_passive (0.591), rhetorical_question (0.586), inclusive_we (0.568), second_person (0.564), named_individual (0.556), nominalization (0.549), probability (0.500)
+ - **Weak performers** (F1 < 0.2): anaphora (0.054), qualified_assertion (0.118), simple_conjunction (0.130), nested_clauses (0.133), concrete_setting (0.165), tricolon (0.171), categorical_statement (0.189), self_correction (0.189), parallelism (0.191)
+ - **Precision-recall tradeoff**: Most types show balanced precision/recall. Notable exceptions include `anaphora` (0.800 precision / 0.028 recall), `parallelism` (0.741 / 0.110), and `intensifier_doubling` (0.680 / 0.200), which remain high-precision but very low-recall.
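
Given the high-precision/low-recall types above, downstream users may want to threshold per-type span probabilities rather than take raw argmax decisions. A minimal sketch under assumed shapes and an assumed threshold value (not part of the released model):

```python
import torch
import torch.nn.functional as F

def spans_above_threshold(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Mark a token as inside a span for a type only when the combined
    B+I probability clears the threshold.

    logits: (batch, seq, num_types, 3) with O/B/I in the last dimension.
    Returns a boolean mask of shape (batch, seq, num_types).
    """
    probs = F.softmax(logits, dim=-1)
    bi_prob = probs[..., 1] + probs[..., 2]   # P(B) + P(I) per type
    return bi_prob > threshold
```

Raising the threshold trades recall for precision on a per-type basis, which is useful for the over-predicting types.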

  ## Architecture

  Custom `MultiLabelTokenClassifier` with independent B/I/O heads per marker type:
  ```
+ ModernBERT (answerdotai/ModernBERT-base)
  └── Dropout (p=0.1)
+ └── Linear (hidden_size → num_types × 3)
  └── Reshape to (batch, seq, num_types, 3)
  ```
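
The head above is just a linear projection reshaped into per-type logits. A minimal sketch with the hidden size and type count from the table (random tensors stand in for the ModernBERT backbone output):

```python
import torch
import torch.nn as nn

num_types, hidden_size = 53, 768

dropout = nn.Dropout(p=0.1)
classifier = nn.Linear(hidden_size, num_types * 3)  # O/B/I logits per type

# Stand-in for the backbone's last hidden state: (batch, seq, hidden_size)
hidden = torch.randn(2, 16, hidden_size)
logits = classifier(dropout(hidden))
logits = logits.view(2, 16, num_types, 3)  # (batch, seq, num_types, 3)

# Each type is an independent 3-way decision, so per-type argmax
# can yield overlapping spans across types.
tags = logits.argmax(dim=-1)  # (batch, seq, num_types)
```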

  ### Regularization

  - **Mixout** (p=0.1): During training, each backbone weight element has a 10% chance of being replaced by its pretrained value per forward pass, acting as a stochastic L2 anchor that prevents representation drift (Lee et al., 2020)
+ - **Per-type focal loss** (γ=2.0): Focuses learning on hard examples, reducing the contribution of easy negatives
  - **Inverse-frequency type weights**: Rare marker types receive higher loss weighting
  - **Inverse-frequency OBI weights**: B and I classes upweighted relative to dominant O class
  - **Weighted random sampling**: Examples containing rarer markers sampled more frequently
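
A per-type focal loss of the kind listed above can be sketched as follows. This is a minimal illustration under assumed shapes and label encoding; the repository's exact weighting scheme may differ:

```python
import torch
import torch.nn.functional as F

def per_type_focal_loss(logits, labels, type_weights, obi_weights, gamma=2.0):
    """Focal cross-entropy over (batch, seq, num_types, 3) logits.

    labels: (batch, seq, num_types) with assumed encoding O=0, B=1, I=2.
    type_weights: (num_types,) inverse-frequency weights per marker type.
    obi_weights: (3,) weights upweighting B/I relative to the dominant O.
    """
    batch, seq, num_types, _ = logits.shape
    flat_logits = logits.reshape(-1, 3)
    flat_labels = labels.reshape(-1)
    log_p = F.log_softmax(flat_logits, dim=-1)
    log_pt = log_p.gather(1, flat_labels.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # Focal term down-weights easy examples (pt near 1)
    focal = (1.0 - pt) ** gamma * (-log_pt)
    focal = focal * obi_weights[flat_labels]
    # Flattening puts the type index fastest, so per-type weights tile cleanly
    tw = type_weights.repeat(batch * seq)
    return (focal * tw).mean()
```

With γ=0 and unit weights this reduces to plain mean cross-entropy, which makes the focal term easy to sanity-check.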

  ### Initialization

+ Fine-tuned from `answerdotai/ModernBERT-base`. Backbone linear layers wrapped with Mixout during training (frozen pretrained copy used as anchor). The classification head is randomly initialized:
  ```
  backbone.* layers → loaded from pretrained, anchored via Mixout
  classifier.weight → randomly initialized

  ## Limitations

+ - **Near-zero recall types**: `anaphora` (0.028 recall), `simple_conjunction` (0.102), `parallelism` (0.110), and `tricolon` (0.119) are rarely detected despite being present in training data
+ - **Low-precision types**: `nested_clauses` (0.091), `metadiscourse` (0.140), and `qualified_assertion` (0.143) have precision below 0.15, meaning most predictions for those types are false positives
  - **Context window**: 128 tokens max; longer spans may be truncated
  - **Domain**: Trained primarily on historical/literary texts; may underperform on modern social media
  - **Subjectivity**: Some marker boundaries are inherently ambiguous

  - Ong, Walter J. *Orality and Literacy: The Technologizing of the Word*. Routledge, 1982.
  - Lee, C. et al. "Mixout: Effective Regularization to Finetune Large-scale Pretrained Language Models." ICLR 2020.
+ - Warner, A. et al. "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference." 2024.

  ---

config.json CHANGED
@@ -1,19 +1,26 @@
  {
- "add_cross_attention": false,
  "architectures": [
- "BertForMaskedLM"
  ],
- "attention_probs_dropout_prob": 0.1,
  "auto_map": {
  "AutoModel": "modeling_havelock.HavelockTokenClassifier"
  },
- "bos_token_id": null,
- "classifier_dropout": null,
  "dtype": "float32",
- "eos_token_id": null,
  "gradient_checkpointing": false,
- "hidden_act": "gelu",
- "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
  "0": "O-literate_abstract_noun",
@@ -176,9 +183,9 @@
  "98": "I-oral_antithesis",
  "99": "O-oral_discourse_formula"
  },
  "initializer_range": 0.02,
- "intermediate_size": 3072,
- "is_decoder": false,
  "label2id": {
  "B-literate_abstract_noun": 1,
  "B-literate_additive_formal": 4,
@@ -340,18 +347,59 @@
  "O-oral_tricolon": 153,
  "O-oral_vocative": 156
  },
- "layer_norm_eps": 1e-12,
- "max_position_embeddings": 512,
- "model_type": "bert",
  "num_attention_heads": 12,
- "num_hidden_layers": 12,
  "num_types": 53,
- "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "tie_word_embeddings": true,
  "transformers_version": "5.0.0",
- "type_vocab_size": 2,
- "use_cache": true,
  "use_crf": true,
- "vocab_size": 30522
  }

  {
  "architectures": [
+ "ModernBertForMaskedLM"
  ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
  "auto_map": {
  "AutoModel": "modeling_havelock.HavelockTokenClassifier"
  },
+ "bos_token_id": 50281,
+ "classifier_activation": "gelu",
+ "classifier_bias": false,
+ "classifier_dropout": 0.0,
+ "classifier_pooling": "mean",
+ "cls_token_id": 50281,
+ "decoder_bias": true,
+ "deterministic_flash_attn": false,
  "dtype": "float32",
+ "embedding_dropout": 0.0,
+ "eos_token_id": 50282,
+ "global_attn_every_n_layers": 3,
  "gradient_checkpointing": false,
+ "hidden_activation": "gelu",
  "hidden_size": 768,
  "id2label": {
  "0": "O-literate_abstract_noun",
  "98": "I-oral_antithesis",
  "99": "O-oral_discourse_formula"
  },
+ "initializer_cutoff_factor": 2.0,
  "initializer_range": 0.02,
+ "intermediate_size": 1152,
  "label2id": {
  "B-literate_abstract_noun": 1,
  "B-literate_additive_formal": 4,
  "O-oral_tricolon": 153,
  "O-oral_vocative": 156
  },
+ "layer_norm_eps": 1e-05,
+ "layer_types": [
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ "sliding_attention",
+ "full_attention"
+ ],
+ "local_attention": 128,
+ "max_position_embeddings": 8192,
+ "mlp_bias": false,
+ "mlp_dropout": 0.0,
+ "model_type": "modernbert",
+ "norm_bias": false,
+ "norm_eps": 1e-05,
  "num_attention_heads": 12,
+ "num_hidden_layers": 22,
  "num_types": 53,
+ "pad_token_id": 50283,
  "position_embedding_type": "absolute",
+ "repad_logits_with_grad": false,
+ "rope_parameters": {
+ "full_attention": {
+ "rope_theta": 160000.0,
+ "rope_type": "default"
+ },
+ "sliding_attention": {
+ "rope_theta": 10000.0,
+ "rope_type": "default"
+ }
+ },
+ "sep_token_id": 50282,
+ "sparse_pred_ignore_index": -100,
+ "sparse_prediction": false,
  "tie_word_embeddings": true,
  "transformers_version": "5.0.0",
  "use_crf": true,
+ "vocab_size": 50368
  }
head_config.json CHANGED
@@ -1,5 +1,5 @@
  {
- "model_name": "bert-base-uncased",
  "num_types": 53,
  "hidden_size": 768
  }

  {
+ "model_name": "answerdotai/ModernBERT-base",
  "num_types": 53,
  "hidden_size": 768
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:37d9c74b122fa304421948d1f1bc5ad1d686fb33eab36ae82079c1e8f4a03282
- size 436082548

  version https://git-lfs.github.com/spec/v1
+ oid sha256:5048514ce9b2156eb090a211c30847979362fe5372cb94865373a00b5970726d
+ size 596563588
modeling_havelock.py CHANGED
@@ -1,143 +1,81 @@
- """Custom multi-label token classifier for HuggingFace Hub."""

  import torch
  import torch.nn as nn
- from transformers import BertModel, BertPreTrainedModel


- class MultiLabelCRF(nn.Module):
-     """Independent CRF per marker type for multi-label BIO tagging."""
-
-     def __init__(self, num_types: int) -> None:
-         super().__init__()
          self.num_types = num_types
-         self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
-         self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
-         self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
-         # Placeholder — will be overwritten by loaded weights if present
-         self.register_buffer("emission_bias", torch.zeros(1, 1, 1, 3))
-         self._reset_parameters()
-
-     def _reset_parameters(self) -> None:
-         nn.init.uniform_(self.transitions, -0.1, 0.1)
-         nn.init.uniform_(self.start_transitions, -0.1, 0.1)
-         nn.init.uniform_(self.end_transitions, -0.1, 0.1)
-         with torch.no_grad():
-             self.transitions.data[:, 0, 2] = -10000.0
-             self.start_transitions.data[:, 2] = -10000.0
-
-     def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
-         if self.emission_bias is not None:
-             return emissions + self.emission_bias
-         return emissions
-
-     def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-         """Viterbi decoding.
-
-         Args:
-             emissions: (batch, seq, num_types, 3)
-             mask: (batch, seq) boolean
-
-         Returns: (batch, seq, num_types) best tag sequences
-         """
-         # Apply emission bias before decoding
-         emissions = self._apply_emission_bias(emissions)
-
-         batch, seq, num_types, _ = emissions.shape
-
-         # Reshape to (batch*num_types, seq, 3)
-         em = emissions.permute(0, 2, 1, 3).reshape(batch * num_types, seq, 3)
-         mk = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(batch * num_types, seq)
-
-         BT = batch * num_types
-
-         # Expand params across batch
-         trans = (
-             self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(BT, 3, 3)
-         )
-         start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)
-         end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)
-
-         arange = torch.arange(BT, device=em.device)
-         score = start + em[:, 0]
-         history: list[torch.Tensor] = []
-
-         for i in range(1, seq):
-             broadcast = score.unsqueeze(2) + trans + em[:, i].unsqueeze(1)
-             best_score, best_prev = broadcast.max(dim=1)
-             score = torch.where(mk[:, i].unsqueeze(1), best_score, score)
-             history.append(best_prev)
-
-         score = score + end
-         _, best_last = score.max(dim=1)
-
-         best_paths = torch.zeros(BT, seq, dtype=torch.long, device=em.device)
-         seq_lengths = mk.sum(dim=1).long()
-         best_paths[arange, seq_lengths - 1] = best_last
-
-         for i in range(seq - 2, -1, -1):
-             prev_tag = history[i][arange, best_paths[:, i + 1]]
-             should_update = i < (seq_lengths - 1)
-             best_paths[:, i] = torch.where(should_update, prev_tag, best_paths[:, i])
-
-         return best_paths.reshape(batch, num_types, seq).permute(0, 2, 1)
-
-
- class HavelockTokenClassifier(BertPreTrainedModel):
-     """Multi-label BIO token classifier with independent O/B/I heads per marker type.
-
-     Each token gets num_types independent 3-way classifications, allowing
-     overlapping spans (e.g. a token simultaneously B-anaphora and I-concessive).
-
-     Output logits shape: (batch, seq_len, num_types, 3)
-     """
-
-     def __init__(self, config):
          super().__init__(config)
          self.num_types = config.num_types
-         self.use_crf = getattr(config, "use_crf", False)
-         self.bert = BertModel(config, add_pooling_layer=False)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)

          if self.use_crf:
              self.crf = MultiLabelCRF(config.num_types)

-         self.post_init()

      def forward(self, input_ids, attention_mask=None, **kwargs):
-         hidden = self.bert(
              input_ids=input_ids, attention_mask=attention_mask
          ).last_hidden_state
          hidden = self.dropout(hidden)
          logits = self.classifier(hidden)
          batch, seq, _ = logits.shape
-         logits = logits.view(batch, seq, self.num_types, 3)
-
-         # If CRF is available and we're not training, return decoded tags
-         # stacked with logits so callers can access either
-         if self.use_crf and not self.training:
-             mask = (
-                 attention_mask.bool()
-                 if attention_mask is not None
-                 else torch.ones(batch, seq, dtype=torch.bool, device=logits.device)
-             )
-             # Return logits — callers use .decode() or we add a decode method
-             # For HF pipeline compat, return logits; users call decode separately
-             pass
-
-         return logits

      def decode(self, input_ids, attention_mask=None):
-         """Run forward pass and return Viterbi-decoded tags."""
          logits = self.forward(input_ids, attention_mask)
          if self.use_crf:
              mask = (
                  attention_mask.bool()
                  if attention_mask is not None
-                 else torch.ones(
-                     logits.shape[:2], dtype=torch.bool, device=logits.device
-                 )
              )
              return self.crf.decode(logits, mask)
-         return logits.argmax(dim=-1)
+ """Custom multi-label token classifier, backbone-agnostic."""

  import torch
  import torch.nn as nn
+ from transformers import AutoConfig, AutoModel, PreTrainedModel, PretrainedConfig


+ class HavelockTokenConfig(PretrainedConfig):
+     """Config that wraps any backbone config + our custom fields."""
+     model_type = "havelock_token_classifier"

+     def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs):
+         super().__init__(**kwargs)
          self.num_types = num_types
+         self.use_crf = use_crf


+ class HavelockTokenClassifier(PreTrainedModel):
+     config_class = HavelockTokenConfig

+     def __init__(self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None):
          super().__init__(config)
          self.num_types = config.num_types
+         self.use_crf = config.use_crf
+
+         # Accept injected backbone (from_pretrained path) or build from config
+         if backbone is not None:
+             self.backbone = backbone
+         else:
+             self.backbone = AutoModel.from_config(config)
+
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)

          if self.use_crf:
              self.crf = MultiLabelCRF(config.num_types)

+     @classmethod
+     def from_backbone(
+         cls,
+         model_name: str,
+         num_types: int,
+         use_crf: bool = False,
+         obi_bias: torch.Tensor | None = None,
+     ) -> "HavelockTokenClassifier":
+         """Build from a pretrained backbone name — the training entrypoint."""
+         backbone = AutoModel.from_pretrained(model_name)
+         backbone_config = backbone.config
+
+         config = HavelockTokenConfig(
+             num_types=num_types,
+             use_crf=use_crf,
+             **backbone_config.to_dict(),
+         )
+
+         model = cls(config, backbone=backbone)
+
+         if use_crf and obi_bias is not None:
+             model.crf.emission_bias = obi_bias.reshape(1, 1, 1, 3)
+
+         return model

      def forward(self, input_ids, attention_mask=None, **kwargs):
+         hidden = self.backbone(
              input_ids=input_ids, attention_mask=attention_mask
          ).last_hidden_state
          hidden = self.dropout(hidden)
          logits = self.classifier(hidden)
          batch, seq, _ = logits.shape
+         return logits.view(batch, seq, self.num_types, 3)

      def decode(self, input_ids, attention_mask=None):
          logits = self.forward(input_ids, attention_mask)
          if self.use_crf:
              mask = (
                  attention_mask.bool()
                  if attention_mask is not None
+                 else torch.ones(logits.shape[:2], dtype=torch.bool, device=logits.device)
              )
              return self.crf.decode(logits, mask)
+         return logits.argmax(dim=-1)
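
The `MultiLabelCRF.decode` method shown in this file batches Viterbi over every (example, type) pair at once; the recurrence is easier to follow on a single sequence. A minimal reference sketch (not the repository's code) using the same 3-tag layout, with the O→I transition blocked as in `_reset_parameters`:

```python
import torch

def viterbi_single(emissions, transitions, start, end):
    """Reference Viterbi for one sequence of 3-tag (O/B/I) emissions.

    emissions: (seq, 3); transitions[i, j]: score of moving tag i -> j.
    Returns the highest-scoring tag sequence as a list of ints.
    """
    seq = emissions.shape[0]
    score = start + emissions[0]          # (3,) best score ending in each tag
    backptr = []
    for t in range(1, seq):
        # score[i] + transitions[i, j] + emissions[t, j], maximized over i
        total = score.unsqueeze(1) + transitions + emissions[t]
        score, best_prev = total.max(dim=0)
        backptr.append(best_prev)
    score = score + end
    best = int(score.argmax())
    path = [best]
    for best_prev in reversed(backptr):   # follow back-pointers
        best = int(best_prev[best])
        path.append(best)
    return path[::-1]
```

With O→I and start-at-I penalized by -10000 (as the model's `_reset_parameters` does), the decoder can never open a span on an I tag.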
tokenizer.json CHANGED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -1,14 +1,16 @@
  {
  "backend": "tokenizers",
  "cls_token": "[CLS]",
- "do_lower_case": true,
  "is_local": false,
  "mask_token": "[MASK]",
- "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
- "strip_accents": null,
- "tokenize_chinese_chars": true,
- "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
  }

  {
  "backend": "tokenizers",
+ "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "is_local": false,
  "mask_token": "[MASK]",
+ "model_input_names": [
+ "input_ids",
+ "attention_mask"
+ ],
+ "model_max_length": 8192,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
+ "tokenizer_class": "TokenizersBackend",
  "unk_token": "[UNK]"
  }