thivy committed
Commit d83a67e · verified · Parent: 8723c4e

Training in progress, step 500, checkpoint

last-checkpoint/1_SpladePooling/config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "pooling_strategy": "max",
+   "activation_function": "relu",
+   "word_embedding_dimension": 51200
+ }
last-checkpoint/README.md ADDED
@@ -0,0 +1,554 @@
1
+ ---
2
+ language:
3
+ - 'no'
4
+ - da
5
+ - sv
6
+ license: mit
7
+ tags:
8
+ - sentence-transformers
9
+ - sparse-encoder
10
+ - sparse
11
+ - splade
12
+ - generated_from_trainer
13
+ - dataset_size:333547
14
+ - loss:SpladeLoss
15
+ - loss:SparseMultipleNegativesRankingLoss
16
+ - loss:FlopsLoss
17
+ base_model: ltg/norbert4-base
18
+ widget:
19
+ - text: "\n \nJeg begyndte at forstå, hvilke vældige kræfter min lille historie\
20
+ \ havde sluppet løs.\n \n "
21
+ - text: "\n \nIfølge Empires job-bibel skal en direktør-assistent ikke dække bord.\n\
22
+ \ \n "
23
+ - text: "\n \nDet kan du da ikke gøre!\n "
24
+ - text: "\n \nJeg må købe flere sherbet fountains.\n \n "
25
+ - text: Søren Kierkegaard, den danske filosof og teolog, var dybt fascineret af begrebet
26
+ tro. I sine mange skrifter udforskede han troens natur, dens paradokser og dens
27
+ betydning for det individuelle liv. Han anså troen for at være et ”spring i det
28
+ forlommede”, en akt af vilje der overstiger fornuften. I værker som ”Frygt og
29
+ Trekken” og ”Sygdommen til Døden” analyserede han troens relation til angst, desperation
30
+ og den eksistentielle krise. Kierkegaards tanker om tro har haft stor indflydelse
31
+ på kristen teologi og eksistentialisme.
32
+ pipeline_tag: feature-extraction
33
+ library_name: sentence-transformers
34
+ metrics:
35
+ - dot_accuracy@1
36
+ - dot_accuracy@3
37
+ - dot_accuracy@5
38
+ - dot_accuracy@10
39
+ - dot_precision@1
40
+ - dot_precision@3
41
+ - dot_precision@5
42
+ - dot_precision@10
43
+ - dot_recall@1
44
+ - dot_recall@3
45
+ - dot_recall@5
46
+ - dot_recall@10
47
+ - dot_ndcg@10
48
+ - dot_mrr@10
49
+ - dot_map@100
50
+ - query_active_dims
51
+ - query_sparsity_ratio
52
+ - corpus_active_dims
53
+ - corpus_sparsity_ratio
54
+ - avg_flops
55
+ model-index:
56
+ - name: Regular SPLADE NorBERT4-base — Retrieval-Only Training
57
+ results:
58
+ - task:
59
+ type: sparse-information-retrieval
60
+ name: Sparse Information Retrieval
61
+ dataset:
62
+ name: NanoNFCorpus
63
+ type: NanoNFCorpus
64
+ metrics:
65
+ - type: dot_accuracy@1
66
+ value: 0.02
67
+ name: Dot Accuracy@1
68
+ - type: dot_accuracy@3
69
+ value: 0.08
70
+ name: Dot Accuracy@3
71
+ - type: dot_accuracy@5
72
+ value: 0.08
73
+ name: Dot Accuracy@5
74
+ - type: dot_accuracy@10
75
+ value: 0.12
76
+ name: Dot Accuracy@10
77
+ - type: dot_precision@1
78
+ value: 0.02
79
+ name: Dot Precision@1
80
+ - type: dot_precision@3
81
+ value: 0.03333333333333333
82
+ name: Dot Precision@3
83
+ - type: dot_precision@5
84
+ value: 0.032
85
+ name: Dot Precision@5
86
+ - type: dot_precision@10
87
+ value: 0.026000000000000006
88
+ name: Dot Precision@10
89
+ - type: dot_recall@1
90
+ value: 7.905138339920947e-05
91
+ name: Dot Recall@1
92
+ - type: dot_recall@3
93
+ value: 0.003312410422185988
94
+ name: Dot Recall@3
95
+ - type: dot_recall@5
96
+ value: 0.004545769460972766
97
+ name: Dot Recall@5
98
+ - type: dot_recall@10
99
+ value: 0.006349071275176555
100
+ name: Dot Recall@10
101
+ - type: dot_ndcg@10
102
+ value: 0.027178706104522946
103
+ name: Dot Ndcg@10
104
+ - type: dot_mrr@10
105
+ value: 0.05088888888888889
106
+ name: Dot Mrr@10
107
+ - type: dot_map@100
108
+ value: 0.006747512755501429
109
+ name: Dot Map@100
110
+ - type: query_active_dims
111
+ value: 51200.0
112
+ name: Query Active Dims
113
+ - type: query_sparsity_ratio
114
+ value: 0.0
115
+ name: Query Sparsity Ratio
116
+ - type: corpus_active_dims
117
+ value: 51200.0
118
+ name: Corpus Active Dims
119
+ - type: corpus_sparsity_ratio
120
+ value: 0.0
121
+ name: Corpus Sparsity Ratio
122
+ - type: avg_flops
123
+ value: 51200.0
124
+ name: Avg Flops
125
+ ---
126
+
127
+ # Regular SPLADE NorBERT4-base — Retrieval-Only Training
128
+
129
+ This is a [SPLADE Sparse Encoder](https://www.sbert.net/docs/sparse_encoder/usage/usage.html) model finetuned from [ltg/norbert4-base](https://huggingface.co/ltg/norbert4-base) using the [sentence-transformers](https://www.SBERT.net) library. It maps sentences & paragraphs to a 51200-dimensional sparse vector space and can be used for semantic search and sparse retrieval.
130
+ ## Model Details
131
+
132
+ ### Model Description
133
+ - **Model Type:** SPLADE Sparse Encoder
134
+ - **Base model:** [ltg/norbert4-base](https://huggingface.co/ltg/norbert4-base) <!-- at revision f04e0e824de9ff9a08767727dc8891d38fddd032 -->
135
+ - **Maximum Sequence Length:** None (no explicit limit set)
136
+ - **Output Dimensionality:** 51200 dimensions
137
+ - **Similarity Function:** Dot Product
138
+ <!-- - **Training Dataset:** Unknown -->
139
+ - **Languages:** no, da, sv
140
+ - **License:** mit
141
+
142
+ ### Model Sources
143
+
144
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
145
+ - **Documentation:** [Sparse Encoder Documentation](https://www.sbert.net/docs/sparse_encoder/usage/usage.html)
146
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/huggingface/sentence-transformers)
147
+ - **Hugging Face:** [Sparse Encoders on Hugging Face](https://huggingface.co/models?library=sentence-transformers&other=sparse-encoder)
148
+
149
+ ### Full Model Architecture
150
+
151
+ ```
152
+ SparseEncoder(
153
+ (0): MLMTransformer({'max_seq_length': None, 'do_lower_case': False, 'architecture': 'GptBertForMaskedLM'})
154
+ (1): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': 51200})
155
+ )
156
+ ```
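+ 
+ For intuition: the SpladePooling module above turns the MLM logits into one sparse vector per text by max-pooling log-saturated activations over the sequence. A minimal sketch of that computation (hand-written here for illustration, not the library's exact implementation):
+ 
+ ```python
+ import torch
+ 
+ def splade_pool(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+     """Collapse MLM logits [batch, seq_len, vocab] into sparse vectors [batch, vocab]."""
+     # log(1 + ReLU(x)): non-negative activations with dampened large logits
+     activations = torch.log1p(torch.relu(logits))
+     # Ignore padding positions before max-pooling over the sequence axis
+     activations = activations.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
+     return activations.max(dim=1).values
+ ```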
157
+
158
+ ## Usage
159
+
160
+ ### Direct Usage (Sentence Transformers)
161
+
162
+ First install the Sentence Transformers library:
163
+
164
+ ```bash
165
+ pip install -U sentence-transformers
166
+ ```
167
+
168
+ Then you can load this model and run inference.
169
+ ```python
170
+ from sentence_transformers import SparseEncoder
171
+
172
+ # Download from the 🤗 Hub
173
+ model = SparseEncoder("thivy/norbert4-base-splade-retrieval")
174
+ # Run inference
175
+ sentences = [
176
+ '\n \nJeg vil ikke ha noen innvendinger.\n \n ',
177
+ '\n \nJeg ville ikke have nogen indvendinger.\n \n ',
178
+ 'Søren Kierkegaard, den danske filosof og teolog, var dybt fascineret af begrebet tro. I sine mange skrifter udforskede han troens natur, dens paradokser og dens betydning for det individuelle liv. Han anså troen for at være et ”spring i det forlommede”, en akt af vilje der overstiger fornuften. I værker som ”Frygt og Trekken” og ”Sygdommen til Døden” analyserede han troens relation til angst, desperation og den eksistentielle krise. Kierkegaards tanker om tro har haft stor indflydelse på kristen teologi og eksistentialisme.',
179
+ ]
180
+ embeddings = model.encode(sentences)
181
+ print(embeddings.shape)
182
+ # [3, 51200]
183
+
184
+ # Get the similarity scores for the embeddings
185
+ similarities = model.similarity(embeddings, embeddings)
186
+ print(similarities)
187
+ # tensor([[ 8.0400, 6.6640, 6.9193],
188
+ # [ 6.6640, 10.4033, 9.1223],
189
+ # [ 6.9193, 9.1223, 20.8932]])
190
+ ```
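+ 
+ Because the vectors live in the vocabulary space, the active dimensions are interpretable as weighted tokens. A short follow-up sketch, assuming the `SparseEncoder.decode` helper available in recent sentence-transformers releases:
+ 
+ ```python
+ # Inspect the highest-weighted vocabulary dimensions of the first embedding
+ decoded = model.decode(embeddings[0], top_k=10)
+ for token, weight in decoded:
+     print(f"{token}: {weight:.2f}")
+ ```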
191
+
192
+ <!--
193
+ ### Direct Usage (Transformers)
194
+
195
+ <details><summary>Click to see the direct usage in Transformers</summary>
196
+
197
+ </details>
198
+ -->
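+ 
+ To load the underlying masked-LM checkpoint with plain `transformers` instead, note that the repository ships custom `configuration_gptbert`/`modeling_gptbert` code via `auto_map`, so `trust_remote_code=True` is required. A minimal sketch (repo ids as used elsewhere in this card):
+ 
+ ```python
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ 
+ tokenizer = AutoTokenizer.from_pretrained("ltg/norbert4-base")
+ # The custom GptBert classes are resolved through auto_map at load time
+ model = AutoModelForMaskedLM.from_pretrained(
+     "thivy/norbert4-base-splade-retrieval", trust_remote_code=True
+ )
+ ```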
199
+
200
+ <!--
201
+ ### Downstream Usage (Sentence Transformers)
202
+
203
+ You can finetune this model on your own dataset.
204
+
205
+ <details><summary>Click to expand</summary>
206
+
207
+ </details>
208
+ -->
209
+
210
+ <!--
211
+ ### Out-of-Scope Use
212
+
213
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
214
+ -->
215
+
216
+ ## Evaluation
217
+
218
+ ### Metrics
219
+
220
+ #### Sparse Information Retrieval
221
+
222
+ * Dataset: `NanoNFCorpus`
223
+ * Evaluated with [<code>SparseInformationRetrievalEvaluator</code>](https://sbert.net/docs/package_reference/sparse_encoder/evaluation.html#sentence_transformers.sparse_encoder.evaluation.SparseInformationRetrievalEvaluator)
224
+
225
+ | Metric | Value |
226
+ |:----------------------|:-----------|
227
+ | dot_accuracy@1 | 0.02 |
228
+ | dot_accuracy@3 | 0.08 |
229
+ | dot_accuracy@5 | 0.08 |
230
+ | dot_accuracy@10 | 0.12 |
231
+ | dot_precision@1 | 0.02 |
232
+ | dot_precision@3 | 0.0333 |
233
+ | dot_precision@5 | 0.032 |
234
+ | dot_precision@10 | 0.026 |
235
+ | dot_recall@1 | 0.0001 |
236
+ | dot_recall@3 | 0.0033 |
237
+ | dot_recall@5 | 0.0045 |
238
+ | dot_recall@10 | 0.0063 |
239
+ | **dot_ndcg@10** | **0.0272** |
240
+ | dot_mrr@10 | 0.0509 |
241
+ | dot_map@100 | 0.0067 |
242
+ | query_active_dims | 51200.0 |
243
+ | query_sparsity_ratio | 0.0 |
244
+ | corpus_active_dims | 51200.0 |
245
+ | corpus_sparsity_ratio | 0.0 |
246
+ | avg_flops | 51200.0 |
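+ 
+ Note that `query_sparsity_ratio` and `corpus_sparsity_ratio` are simply the fraction of inactive dimensions, so with all 51200 dimensions active at this checkpoint both ratios are 0.0 and the outputs are still effectively dense:
+ 
+ ```python
+ # sparsity_ratio = 1 - active_dims / output_dims
+ print(1 - 51200 / 51200)  # 0.0 -> no dimensions are inactive yet at step 500
+ ```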
247
+
248
+ <!--
249
+ ## Bias, Risks and Limitations
250
+
251
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
252
+ -->
253
+
254
+ <!--
255
+ ### Recommendations
256
+
257
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
258
+ -->
259
+
260
+ ## Training Details
261
+
262
+ ### Training Dataset
263
+
264
+ #### Unnamed Dataset
265
+
266
+ * Size: 333,547 training samples
267
+ * Columns: <code>anchor</code> and <code>positive</code>
268
+ * Approximate statistics based on the first 1000 samples:
269
+ | | anchor | positive |
270
+ |:--------|:-----------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------|
271
+ | type | string | string |
272
+ | details | <ul><li>min: 3 tokens</li><li>mean: 22.81 tokens</li><li>max: 517 tokens</li></ul> | <ul><li>min: 1 tokens</li><li>mean: 406.29 tokens</li><li>max: 4096 tokens</li></ul> |
273
+ * Samples:
274
+ | anchor | positive |
275
+ |:---------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
276
+ | <code><br>Hun er mye eldre enn henne.<br> <br> </code> | <code><br>Hun er meget ældre end hende.<br> <br> </code> |
277
+ | <code><br>Hva så? <br> <br>Du lå med kona mi!<br> <br> </code> | <code><br>Men du gik i seng med min kone.<br> <br> </code> |
278
+ | <code>Hur aktiverar jag en indeksfond?</code> | <code>Att investera i indexfonder är ett populärt sätt att exponera sig mot aktiemarknaden. Det är ett passivt investeringsalternativ där portföljen följer en specifik index, till exempel OMX Stockholm 30.<br><br>För att aktivera en indexfond behöver du ett depåkonto hos en bank eller en investmentsmäklare. Innan du påbörjar processen bör du noggrant undersöka och jämföra olika fonder för att hitta den som bäst passar dina investeringsmål och risktolerans.<br><br>När du väl har valt en fond kan du vanligtvis aktivera den online via bankens eller mäklarens plattform. Du behöver ange hur mycket du vill investera och godkänna villkoren. Därefter kommer fonden att köpas och lagts till i ditt depåkonto.<br><br>Det är viktigt att ha en långsiktig investeringshorisont när du investerar i indexfonder. Marknaderna fluktuerar i värde på kort sikt, men över tid har indexfonder historiskt sett genererat goda avkastningar.</code> |
279
+ * Loss: [<code>SpladeLoss</code>](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:
280
+ ```json
281
+ {
282
+ "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct='dot_score', gather_across_devices=False)",
283
+ "document_regularizer_weight": 0.003,
284
+ "query_regularizer_weight": 0.0001
285
+ }
286
+ ```
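+ 
+ The two regularizer weights scale the FLOPS penalties (see the FlopsLoss citation below) that push document and query vectors toward sparsity; documents are regularized more strongly (0.003) than queries (0.0001). A minimal sketch of the FLOPS term itself, not the library's exact code:
+ 
+ ```python
+ import torch
+ 
+ def flops_loss(embeddings: torch.Tensor) -> torch.Tensor:
+     """Sum over dimensions of the squared mean activation across the batch."""
+     # SPLADE activations are non-negative, so the mean equals the mean absolute value
+     return (embeddings.mean(dim=0) ** 2).sum()
+ ```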
287
+
288
+ ### Evaluation Dataset
289
+
290
+ #### Unnamed Dataset
291
+
292
+ * Size: 14,458 evaluation samples
293
+ * Columns: <code>anchor</code> and <code>positive</code>
294
+ * Approximate statistics based on the first 1000 samples:
295
+ | | anchor | positive |
296
+ |:--------|:----------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------|
297
+ | type | string | string |
298
+ | details | <ul><li>min: 3 tokens</li><li>mean: 16.03 tokens</li><li>max: 86 tokens</li></ul> | <ul><li>min: 7 tokens</li><li>mean: 134.75 tokens</li><li>max: 4096 tokens</li></ul> |
299
+ * Samples:
300
+ | anchor | positive |
301
+ |:--------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------|
302
+ | <code><br> <br>Hva er det for organisasjon som skal ha årsmøte her?<br> <br> </code> | <code><br> <br>Hvilken organisation skal holde kongres her?<br> <br> </code> |
303
+ | <code><br>Livet ditt er jo ikke så verst.<br> <br> </code> | <code><br>Dit liv er ikke så slemt.<br> <br> </code> |
304
+ | <code><br> <br>Men du må ta deg av dem for meg, okay?<br> <br> </code> | <code><br> <br>Men du må tage dig af dem for mig, okay?<br> <br> </code> |
305
+ * Loss: [<code>SpladeLoss</code>](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:
306
+ ```json
307
+ {
308
+ "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct='dot_score', gather_across_devices=False)",
309
+ "document_regularizer_weight": 0.003,
310
+ "query_regularizer_weight": 0.0001
311
+ }
312
+ ```
313
+
314
+ ### Training Hyperparameters
315
+ #### Non-Default Hyperparameters
316
+
317
+ - `eval_strategy`: steps
318
+ - `per_device_train_batch_size`: 16
319
+ - `per_device_eval_batch_size`: 32
320
+ - `learning_rate`: 2e-05
321
+ - `weight_decay`: 0.01
322
+ - `num_train_epochs`: 1
323
+ - `warmup_ratio`: 0.1
324
+ - `bf16`: True
325
+ - `dataloader_num_workers`: 2
326
+ - `dataloader_prefetch_factor`: 2
327
+ - `load_best_model_at_end`: True
328
+ - `ddp_find_unused_parameters`: True
329
+ - `push_to_hub`: True
330
+ - `hub_model_id`: thivy/norbert4-base-splade-retrieval
331
+ - `hub_strategy`: checkpoint
332
+ - `hub_private_repo`: False
333
+ - `gradient_checkpointing`: True
334
+ - `gradient_checkpointing_kwargs`: {'use_reentrant': False}
335
+ - `multi_dataset_batch_sampler`: round_robin
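+ 
+ These map directly onto the trainer's arguments; a hedged reproduction sketch (assuming the `SparseEncoderTrainingArguments` class from sentence-transformers v5):
+ 
+ ```python
+ from sentence_transformers.sparse_encoder import SparseEncoderTrainingArguments
+ 
+ args = SparseEncoderTrainingArguments(
+     output_dir="norbert4-base-splade-retrieval",
+     eval_strategy="steps",
+     per_device_train_batch_size=16,
+     per_device_eval_batch_size=32,
+     learning_rate=2e-5,
+     weight_decay=0.01,
+     num_train_epochs=1,
+     warmup_ratio=0.1,
+     bf16=True,
+     load_best_model_at_end=True,
+ )
+ ```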
336
+
337
+ #### All Hyperparameters
338
+ <details><summary>Click to expand</summary>
339
+
340
+ - `overwrite_output_dir`: False
341
+ - `do_predict`: False
342
+ - `eval_strategy`: steps
343
+ - `prediction_loss_only`: True
344
+ - `per_device_train_batch_size`: 16
345
+ - `per_device_eval_batch_size`: 32
346
+ - `per_gpu_train_batch_size`: None
347
+ - `per_gpu_eval_batch_size`: None
348
+ - `gradient_accumulation_steps`: 1
349
+ - `eval_accumulation_steps`: None
350
+ - `torch_empty_cache_steps`: None
351
+ - `learning_rate`: 2e-05
352
+ - `weight_decay`: 0.01
353
+ - `adam_beta1`: 0.9
354
+ - `adam_beta2`: 0.999
355
+ - `adam_epsilon`: 1e-08
356
+ - `max_grad_norm`: 1.0
357
+ - `num_train_epochs`: 1
358
+ - `max_steps`: -1
359
+ - `lr_scheduler_type`: linear
360
+ - `lr_scheduler_kwargs`: {}
361
+ - `warmup_ratio`: 0.1
362
+ - `warmup_steps`: 0
363
+ - `log_level`: passive
364
+ - `log_level_replica`: warning
365
+ - `log_on_each_node`: True
366
+ - `logging_nan_inf_filter`: True
367
+ - `save_safetensors`: True
368
+ - `save_on_each_node`: False
369
+ - `save_only_model`: False
370
+ - `restore_callback_states_from_checkpoint`: False
371
+ - `no_cuda`: False
372
+ - `use_cpu`: False
373
+ - `use_mps_device`: False
374
+ - `seed`: 42
375
+ - `data_seed`: None
376
+ - `jit_mode_eval`: False
377
+ - `bf16`: True
378
+ - `fp16`: False
379
+ - `fp16_opt_level`: O1
380
+ - `half_precision_backend`: auto
381
+ - `bf16_full_eval`: False
382
+ - `fp16_full_eval`: False
383
+ - `tf32`: None
384
+ - `local_rank`: 0
385
+ - `ddp_backend`: None
386
+ - `tpu_num_cores`: None
387
+ - `tpu_metrics_debug`: False
388
+ - `debug`: []
389
+ - `dataloader_drop_last`: True
390
+ - `dataloader_num_workers`: 2
391
+ - `dataloader_prefetch_factor`: 2
392
+ - `past_index`: -1
393
+ - `disable_tqdm`: False
394
+ - `remove_unused_columns`: True
395
+ - `label_names`: None
396
+ - `load_best_model_at_end`: True
397
+ - `ignore_data_skip`: False
398
+ - `fsdp`: []
399
+ - `fsdp_min_num_params`: 0
400
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
401
+ - `fsdp_transformer_layer_cls_to_wrap`: None
402
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
403
+ - `parallelism_config`: None
404
+ - `deepspeed`: None
405
+ - `label_smoothing_factor`: 0.0
406
+ - `optim`: adamw_torch_fused
407
+ - `optim_args`: None
408
+ - `adafactor`: False
409
+ - `group_by_length`: False
410
+ - `length_column_name`: length
411
+ - `project`: huggingface
412
+ - `trackio_space_id`: trackio
413
+ - `ddp_find_unused_parameters`: True
414
+ - `ddp_bucket_cap_mb`: None
415
+ - `ddp_broadcast_buffers`: False
416
+ - `dataloader_pin_memory`: True
417
+ - `dataloader_persistent_workers`: False
418
+ - `skip_memory_metrics`: True
419
+ - `use_legacy_prediction_loop`: False
420
+ - `push_to_hub`: True
421
+ - `resume_from_checkpoint`: None
422
+ - `hub_model_id`: thivy/norbert4-base-splade-retrieval
423
+ - `hub_strategy`: checkpoint
424
+ - `hub_private_repo`: False
425
+ - `hub_always_push`: False
426
+ - `hub_revision`: None
427
+ - `gradient_checkpointing`: True
428
+ - `gradient_checkpointing_kwargs`: {'use_reentrant': False}
429
+ - `include_inputs_for_metrics`: False
430
+ - `include_for_metrics`: []
431
+ - `eval_do_concat_batches`: True
432
+ - `fp16_backend`: auto
433
+ - `push_to_hub_model_id`: None
434
+ - `push_to_hub_organization`: None
435
+ - `mp_parameters`:
436
+ - `auto_find_batch_size`: False
437
+ - `full_determinism`: False
438
+ - `torchdynamo`: None
439
+ - `ray_scope`: last
440
+ - `ddp_timeout`: 1800
441
+ - `torch_compile`: False
442
+ - `torch_compile_backend`: None
443
+ - `torch_compile_mode`: None
444
+ - `include_tokens_per_second`: False
445
+ - `include_num_input_tokens_seen`: no
446
+ - `neftune_noise_alpha`: None
447
+ - `optim_target_modules`: None
448
+ - `batch_eval_metrics`: False
449
+ - `eval_on_start`: False
450
+ - `use_liger_kernel`: False
451
+ - `liger_kernel_config`: None
452
+ - `eval_use_gather_object`: False
453
+ - `average_tokens_across_devices`: True
454
+ - `prompts`: None
455
+ - `batch_sampler`: batch_sampler
456
+ - `multi_dataset_batch_sampler`: round_robin
457
+ - `router_mapping`: {}
458
+ - `learning_rate_mapping`: {}
459
+
460
+ </details>
461
+
462
+ ### Training Logs
463
+ | Epoch | Step | Training Loss | Validation Loss | NanoNFCorpus_dot_ndcg@10 |
464
+ |:------:|:----:|:-------------:|:---------------:|:------------------------:|
465
+ | 0.0048 | 50 | 37895.69 | - | - |
466
+ | 0.0096 | 100 | 10002.0562 | - | - |
467
+ | 0.0144 | 150 | 3805.4731 | - | - |
468
+ | 0.0192 | 200 | 923.0944 | - | - |
469
+ | 0.0240 | 250 | 514.7795 | - | - |
470
+ | 0.0288 | 300 | 284.5449 | - | - |
471
+ | 0.0336 | 350 | 90.0678 | - | - |
472
+ | 0.0384 | 400 | 30.8482 | - | - |
473
+ | 0.0432 | 450 | 2.5071 | - | - |
474
+ | 0.0480 | 500 | 1.3525 | 2.2663 | 0.0272 |
475
+
476
+
477
+ ### Framework Versions
478
+ - Python: 3.12.12
479
+ - Sentence Transformers: 5.2.0
480
+ - Transformers: 4.57.3
481
+ - PyTorch: 2.9.1+cu128
482
+ - Accelerate: 1.12.0
483
+ - Datasets: 4.4.2
484
+ - Tokenizers: 0.22.2
485
+
486
+ ## Citation
487
+
488
+ ### BibTeX
489
+
490
+ #### Sentence Transformers
491
+ ```bibtex
492
+ @inproceedings{reimers-2019-sentence-bert,
493
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
494
+ author = "Reimers, Nils and Gurevych, Iryna",
495
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
496
+ month = "11",
497
+ year = "2019",
498
+ publisher = "Association for Computational Linguistics",
499
+ url = "https://arxiv.org/abs/1908.10084",
500
+ }
501
+ ```
502
+
503
+ #### SpladeLoss
504
+ ```bibtex
505
+ @misc{formal2022distillationhardnegativesampling,
506
+ title={From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective},
507
+ author={Thibault Formal and Carlos Lassance and Benjamin Piwowarski and Stéphane Clinchant},
508
+ year={2022},
509
+ eprint={2205.04733},
510
+ archivePrefix={arXiv},
511
+ primaryClass={cs.IR},
512
+ url={https://arxiv.org/abs/2205.04733},
513
+ }
514
+ ```
515
+
516
+ #### SparseMultipleNegativesRankingLoss
517
+ ```bibtex
518
+ @misc{henderson2017efficient,
519
+ title={Efficient Natural Language Response Suggestion for Smart Reply},
520
+ author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
521
+ year={2017},
522
+ eprint={1705.00652},
523
+ archivePrefix={arXiv},
524
+ primaryClass={cs.CL}
525
+ }
526
+ ```
527
+
528
+ #### FlopsLoss
529
+ ```bibtex
530
+ @article{paria2020minimizing,
531
+ title={Minimizing flops to learn efficient sparse representations},
532
+ author={Paria, Biswajit and Yeh, Chih-Kuan and Yen, Ian EH and Xu, Ning and Ravikumar, Pradeep and P{\'o}czos, Barnab{\'a}s},
533
+ journal={arXiv preprint arXiv:2004.05665},
534
+ year={2020}
535
+ }
536
+ ```
537
+
538
+ <!--
539
+ ## Glossary
540
+
541
+ *Clearly define terms in order to be accessible across audiences.*
542
+ -->
543
+
544
+ <!--
545
+ ## Model Card Authors
546
+
547
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
548
+ -->
549
+
550
+ <!--
551
+ ## Model Card Contact
552
+
553
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
554
+ -->
last-checkpoint/config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "architectures": [
+     "GptBertForMaskedLM"
+   ],
+   "attention_dropout": 0.0,
+   "attn_implementation": null,
+   "auto_map": {
+     "AutoConfig": "configuration_gptbert.GptBertConfig",
+     "AutoModel": "modeling_gptbert.GptBertModel",
+     "AutoModelForCausalLM": "modeling_gptbert.GptBertForCausalLM",
+     "AutoModelForMaskedLM": "modeling_gptbert.GptBertForMaskedLM",
+     "AutoModelForMultipleChoice": "modeling_gptbert.GptBertForMultipleChoice",
+     "AutoModelForQuestionAnswering": "modeling_gptbert.GptBertForQuestionAnswering",
+     "AutoModelForSequenceClassification": "modeling_gptbert.GptBertForSequenceClassification",
+     "AutoModelForTokenClassification": "modeling_gptbert.GptBertForTokenClassification"
+   },
+   "bos_token_id": 1,
+   "classifier_dropout": 0.2,
+   "deterministic_flash_attn": false,
+   "dtype": "float32",
+   "embedding_dropout": 0.1,
+   "eos_token_id": 2,
+   "global_window_length": 8192,
+   "hidden_dropout": 0.0,
+   "hidden_size": 640,
+   "intermediate_size": 1664,
+   "layer_norm_eps": 1e-07,
+   "local_global_ratio": 4,
+   "local_window_length": 256,
+   "mask_token_id": 4,
+   "max_sequence_length": 16384,
+   "model": "norbert4",
+   "num_attention_heads": 10,
+   "num_layers": 24,
+   "pad_token_id": 3,
+   "query_key_head_size": 64,
+   "rope_theta": 160000,
+   "transformers_version": "4.57.3",
+   "unk_token_id": 0,
+   "use_cache": false,
+   "value_head_size": 64,
+   "vocab_size": 51200
+ }
last-checkpoint/config_sentence_transformers.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "model_type": "SparseEncoder",
+   "__version__": {
+     "sentence_transformers": "5.2.0",
+     "transformers": "4.57.3",
+     "pytorch": "2.9.1+cu128"
+   },
+   "prompts": {
+     "query": "",
+     "document": ""
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": "dot"
+ }
last-checkpoint/configuration_gptbert.py ADDED
@@ -0,0 +1,34 @@
+ from __future__ import annotations
+ 
+ import json
+ from pathlib import Path
+ import copy
+ from transformers.configuration_utils import PretrainedConfig
+ 
+ 
+ class GptBertConfig(PretrainedConfig):
+ 
+     def __init__(
+         self,
+         config_file: Path | str | None = None,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.model = "norbert4"
+ 
+         if config_file is not None:
+             if type(config_file) is str:
+                 config_file = Path(config_file)
+             # After the str -> Path conversion above, config_file must be a Path
+             assert type(config_file) is Path, "config_file should be either a Path or a str"
+             with config_file.open("r") as file:
+                 config = json.load(file)
+ 
+             # Attributes from the config file, lower-casing string values
+             for attr, value in config.items():
+                 if isinstance(value, str):
+                     value = value.lower()
+                 setattr(self, attr, value)
+ 
+         # Keyword arguments override any file-provided values
+         for attr, value in kwargs.items():
+             if isinstance(value, str):
+                 value = value.lower()
+             setattr(self, attr, value)
last-checkpoint/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca6b59b6342fcd6a1910b237e5db7707f98673239940cfc25a5d1876082ebc33
+ size 728561776
last-checkpoint/modeling_gptbert.py ADDED
@@ -0,0 +1,1105 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch import _softmax_backward_data
7
+
8
+ from functools import partial, lru_cache
9
+
10
+ from .configuration_gptbert import GptBertConfig
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.activations import gelu_new
13
+ from transformers.utils import is_flash_attn_2_available, logging
14
+ from transformers.modeling_outputs import (
15
+ MaskedLMOutput,
16
+ MultipleChoiceModelOutput,
17
+ QuestionAnsweringModelOutput,
18
+ SequenceClassifierOutput,
19
+ TokenClassifierOutput,
20
+ BaseModelOutput,
21
+ CausalLMOutput
22
+ )
23
+ import math
24
+ from typing import TYPE_CHECKING, Optional, Union, Tuple, List
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # Workaround for transformers < 4.36.0 check_imports issue
30
+ # See: https://github.com/huggingface/transformers/issues/28459
31
+ try:
32
+ if is_flash_attn_2_available():
33
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
34
+ from flash_attn.layers.rotary import RotaryEmbedding
35
+ from flash_attn.ops.triton.rotary import apply_rotary
36
+ else:
37
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
38
+ logger.warning_once(
39
+ "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
40
+ )
41
+ except ImportError:
42
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
43
+ logger.warning_once(
44
+ "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
45
+ )
46
+
47
+
48
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
49
+ @torch.compiler.disable()
50
+ def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
51
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
52
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
53
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
54
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
55
+
56
+ if input_ids.dim() == 2:
57
+ unpadded_inputs = input_ids.flatten()[indices]
58
+ else:
59
+ batch_size, sequence_length, *rest = input_ids.shape
60
+ shape = batch_size * sequence_length
61
+ unpadded_inputs = input_ids.view(shape, *rest)[indices]
62
+
63
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
64
+
65
+
66
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
67
+ def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
68
+ if input_ids.dim() == 1:
69
+ output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
70
+ output[indices] = input_ids
71
+ padded_inputs = output.view(batch_size, sequence_length)
72
+ else:
73
+ _, *rest = input_ids.shape
74
+ output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
75
+ output[indices] = input_ids
76
+ padded_inputs = output.view(batch_size, sequence_length, *rest)
77
+
78
+ return padded_inputs
79
+
80
+
81
+ class CastedLinear(nn.Linear):
82
+ def __init__(self, in_features, out_features, bias):
83
+ super().__init__(in_features, out_features, bias=bias)
84
+
85
+ def forward(self, x):
86
+ return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
87
+
88
+
89
+ class CastedLinearIn(nn.Linear):
90
+ def __init__(self, in_features, out_features, bias):
91
+ super().__init__(in_features, out_features, bias=bias)
92
+ self.scale = nn.Parameter(torch.ones(in_features))
93
+
94
+ def forward(self, x):
95
+ return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
96
+
97
+
98
+ class MultiCastedLinearOrthoIn(nn.Module):
99
+ def __init__(self, in_features, out_features, bias):
100
+ super().__init__()
101
+
102
+ self.in_features = in_features
103
+ self.out_features = out_features
104
+
105
+ self.weights = nn.ParameterList()
106
+ for out_feature in out_features:
107
+ self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
108
+
109
+ if bias:
110
+ self.bias = nn.Parameter(torch.zeros(sum(out_features)))
111
+ else:
112
+ self.bias = self.register_parameter("bias", None)
113
+
114
+ self.scale = nn.Parameter(torch.ones(in_features))
115
+
116
+ def forward(self, x):
117
+ return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
118
+
119
+
120
+ class GeGLU(nn.Module):
121
+ def forward(self, x):
122
+ x, gate = x.chunk(2, dim=-1)
123
+ return x * gelu_new(gate)
124
+
125
+
126
+ class Embedding(nn.Module):
127
+ def __init__(self, config: GptBertConfig):
128
+ super().__init__()
129
+
130
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
131
+ self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
132
+ self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
133
+ self.dropout = nn.Dropout(config.embedding_dropout)
134
+
135
+ def forward(self, input_ids: torch.Tensor):
136
+ word_embedding = self.word_embedding(input_ids)
137
+ word_embedding = self.word_norm(word_embedding)
138
+ word_embedding = word_embedding * (self.word_scale + 1.0)
139
+
140
+ return self.dropout(word_embedding)
141
+
142
+
143
+ class LMClassifier(nn.Module):
144
+ def __init__(self, config: GptBertConfig, n_labels: int):
145
+ super().__init__()
146
+
147
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
148
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
149
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
150
+ self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
151
+
152
+ def forward(self, x: torch.Tensor):
153
+ x = self.pre_norm(x.float()).type_as(x)
154
+ x = self.projection(x)
155
+ x = gelu_new(x)
156
+ x = self.post_norm(x.float()).type_as(x)
157
+ x = self.emb2vocab(x)
158
+ return x
159
+
160
+
161
+ class Classifier(nn.Module):
162
+ def __init__(self, config: GptBertConfig, n_labels: int):
163
+ super().__init__()
164
+
165
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
166
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
167
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
168
+ self.dropout = nn.Dropout(config.classifier_dropout)
169
+ self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
170
+
171
+ def forward(self, x: torch.Tensor):
172
+ x = self.pre_norm(x.float()).type_as(x)
173
+ x = self.projection(x)
174
+ x = gelu_new(x)
175
+ x = self.post_norm(x.float()).type_as(x)
176
+ x = self.dropout(x)
177
+ x = self.output_projection(x)
178
+ return x
179
+
180
+
181
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
182
+ def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
183
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
184
+
185
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
186
+ if convert_dtype:
187
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
188
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
189
+ orig_dtype = qkv.dtype
190
+ qkv = qkv.to(target_dtype)
191
+
192
+ attn = flash_attn_varlen_qkvpacked_func(
193
+ qkv,
194
+ cu_seqlens=cu_seqlens,
195
+ max_seqlen=max_seqlen,
196
+ dropout_p=dropout_p,
197
+ deterministic=deterministic,
198
+ window_size=local_attention,
199
+ causal=causal
200
+ )
201
+ attn = attn.to(orig_dtype) # type: ignore
202
+ else:
203
+ attn = flash_attn_varlen_qkvpacked_func(
204
+ qkv,
205
+ cu_seqlens=cu_seqlens,
206
+ max_seqlen=max_seqlen,
207
+ dropout_p=dropout_p,
208
+ deterministic=deterministic,
209
+ window_size=local_attention,
210
+ causal=causal
211
+ )
212
+ return attn
213
+
214
+
215
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
216
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
217
+ @staticmethod
218
+ def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
219
+ # (total_nnz, 3, nheads, headdim)
220
+ qkv = qkv.contiguous()
221
+ total_nnz, _three, _nheads, headdim = qkv.shape
222
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
223
+ # we get the same tensor
224
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
225
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
226
+ apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
227
+
228
+ ctx.save_for_backward(cos, sin, cu_seqlens)
229
+ ctx.max_seqlen = max_seqlen
230
+ return qkv
231
+
232
+ @staticmethod
233
+ def backward(ctx, do):
234
+ cos, sin, cu_seqlens = ctx.saved_tensors
235
+ do = do.contiguous()
236
+ total_nnz, _three, _nheads, headdim = do.shape
237
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
238
+ # we get the same tensor
239
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
240
+ apply_rotary(
241
+ dqk,
242
+ cos,
243
+ sin,
244
+ seqlen_offsets=0,
245
+ cu_seqlens=cu_seqlens,
246
+ max_seqlen=ctx.max_seqlen,
247
+ interleaved=False,
248
+ inplace=True,
249
+ conjugate=True,
250
+ )
251
+
252
+ return do, None, None, None, None, None, None
253
+
254
+
255
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
256
+ def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
257
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
258
+
259
+
260
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
261
+ class UnpaddedRotaryEmbedding(RotaryEmbedding):
262
+ def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
263
+ super().__init__(dim=dim, base=base, device=None, interleaved=False)
264
+ self.max_seqlen = max_seqlen
265
+
266
+ def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
267
+ if max_seqlen is not None:
268
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
269
+
270
+ qkv = apply_rotary_unpadded(
271
+ qkv,
272
+ self._cos_cached,
273
+ self._sin_cached,
274
+ cu_seqlens=cu_seqlens,
275
+ max_seqlen=max_seqlen,
276
+ )
277
+
278
+ return qkv
279
+
280
+
281
+ class RotaryPositionalEmbeddings(nn.Module):
282
+ def __init__(self, config, theta: int):
283
+ super().__init__()
284
+
285
+ head_size = config.query_key_head_size
286
+ assert head_size % 2 == 0
287
+ max_seq_len = config.max_sequence_length
288
+
289
+ inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
290
+ pos = torch.arange(max_seq_len, dtype=torch.float32)
291
+ embedding = torch.einsum('n, d -> nd', pos, inv_freq)
292
+ embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
293
+ self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
294
+ self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
295
+
296
+ def forward(self, x: torch.Tensor):
297
+ hidden_layer = x.float()
298
+
299
+ seq_len = x.shape[2]
300
+
301
+ cos_matrix = self.cos_matrix[:, None, :seq_len, :]
302
+ sin_matrix = self.sin_matrix[:, None, :seq_len, :]
303
+
304
+ x_rotate_half = torch.cat(
305
+ [
306
+ -hidden_layer[:, :, :, x.size(-1) // 2:],
307
+ hidden_layer[:, :, :, :x.size(-1) // 2]
308
+ ],
309
+ dim=-1
310
+ )
311
+
312
+ out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
313
+ return out.type_as(x)
314
+
315
+
316
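+ # Fused masked softmax: the mask is applied in place and only the softmax output is
+ # saved for backward (via torch's internal _softmax_backward_data), avoiding a
+ # separate masked copy of the attention scores.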
+ class MaskedSoftmax(torch.autograd.Function):
317
+ @staticmethod
318
+ def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
319
+ ctx.dim = dim
320
+ x.masked_fill_(mask, float('-inf'))
321
+ x = torch.softmax(x, ctx.dim)
322
+ x.masked_fill_(mask, 0.0)
323
+ ctx.save_for_backward(x)
324
+ return x
325
+
326
+ @staticmethod
327
+ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
328
+ output: torch.Tensor
329
+
330
+ output, = ctx.saved_tensors
331
+ inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
332
+ return inputGrad, None, None
333
+
334
+
335
+ class SelfAttention(nn.Module):
336
+ def __init__(self, config: GptBertConfig, layer_idx: int):
337
+ super().__init__()
338
+
339
+ self.config = config
340
+ self.layer_idx = layer_idx
341
+
342
+ self.d_qk = config.query_key_head_size
343
+ self.d_v = config.value_head_size
344
+ self.num_attention_heads = config.num_attention_heads
345
+ self.num_kv_heads = config.num_attention_heads
346
+ self.hidden_size = config.hidden_size
347
+
348
+ self.q_out_dim = self.d_qk * self.num_attention_heads
349
+ self.k_out_dim = self.d_qk * self.num_kv_heads
350
+ self.v_out_dim = self.d_v * self.num_kv_heads
351
+
352
+ self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
353
+ self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
354
+ self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
355
+
356
+ self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
357
+ self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
358
+ self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
359
+ self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
360
+ self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
361
+ self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
362
+ self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
363
+
364
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
365
+ self.dropout = nn.Dropout(config.hidden_dropout)
366
+
367
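+ # Every `local_global_ratio`-th layer is a global-attention layer and uses the larger
+ # RoPE base (matching `rope_theta` in the config); the remaining local layers keep
+ # the standard base of 10_000.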
+ theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
368
+
369
+ # Initialize rotary embeddings based on whether FlashAttention is available
370
+ if flash_attn_varlen_qkvpacked_func is not None:
371
+ self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
372
+ else:
373
+ self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
374
+
375
+ self.scale = 1.0 / math.sqrt(self.d_qk)
376
+ self.lambdas = nn.Parameter(torch.tensor([0.5]))
377
+
378
+ self.sequence_length = config.max_sequence_length
379
+ self.is_causal = config.is_decoder
380
+ self.window_length = None
381
+
382
+ def set_window_length(self, window_length: int):
383
+ self.window_length = window_length
384
+
385
+ def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
386
+ """Create and cache window attention mask."""
387
+ if self.is_causal:
388
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
389
+ mask = mask.tril().triu(diagonal=-self.window_length)
390
+ else:
391
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
392
+ mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
393
+ return mask.view(1, 1, query_length, key_length)
394
+
395
+ def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
396
+ """Standard attention computation with masking."""
397
+ batch_size, _, query_length, _ = query.size()
398
+ _, _, key_length, _ = key.size()
399
+
400
+ # Use cached window mask
401
+ with torch.no_grad():
402
+ window_mask = self._get_window_mask(query_length, key_length, query.device)
403
+ if padding_mask is not None:
404
+ attention_mask = padding_mask & window_mask
405
+ else:
406
+ attention_mask = window_mask
407
+
408
+ attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, Q_T, K_T]
409
+ attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
410
+
411
+ attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
412
+ attention_probabilities = self.attention_dropout(attention_probabilities)
413
+
414
+ output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
415
+ output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
416
+
417
+ return output
418
+
419
+ def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
420
+ # Get original shape info
421
+ if flash_attn_varlen_qkvpacked_func is not None:
422
+ # Unpadded case
423
+ indices, cu_seqlens, max_seqlen = padding_info
424
+ total_seqlen = hidden_layer.size(0)
425
+ batch_size = cu_seqlens.size(0) - 1
426
+ else:
427
+ # Padded case
428
+ batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
429
+
430
+ hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
431
+ qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
432
+
433
+ query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
434
+ value = self.v_proj(hidden_layer)
435
+
436
+ if flash_attn_varlen_qkvpacked_func is not None:
437
+ # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
438
+ query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
439
+ key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
440
+ value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
441
+
442
+ # Apply layer norm and scaling
443
+ query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
444
+ key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
445
+
446
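+ # Value-residual mixing: v1 caches the first layer's value projections, and each
+ # later layer blends its own values with them via the learned lambda weight.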
+ if v1 is None:
447
+ v1 = value
448
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
449
+
450
+ # Prepare qkv for FlashAttention
451
+ qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
452
+
453
+ # Determine window size for local attention
454
+ if self.window_length is not None and self.window_length > 0:
455
+ if self.is_causal:
456
+ local_attention = (self.window_length - 1, 0)
457
+ else:
458
+ local_attention = (self.window_length - 1, self.window_length - 1)
459
+ else:
460
+ local_attention = (-1, -1)
461
+
462
+ # Apply FlashAttention
463
+ output = flash_attention_forward(
464
+ qkv,
465
+ self.rope_embedding,
466
+ cu_seqlens,
467
+ max_seqlen,
468
+ self.is_causal,
469
+ local_attention,
470
+ self.config.attention_dropout if self.training else 0.0,
471
+ self.config.deterministic_flash_attn
472
+ )
473
+
474
+ # Reshape output back
475
+ output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
476
+
477
+ else:
478
+ # Standard attention path
479
+ query_length = query.size(1)
480
+ key_length = key.size(1)
481
+
482
+ query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
483
+ key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
484
+ value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
485
+
486
+ query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
487
+ key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
488
+
489
+ if v1 is None:
490
+ v1 = value
491
+ else:
492
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
493
+
494
+ # Apply rotary embeddings
495
+ query = self.rope_embedding(query)
496
+ key = self.rope_embedding(key)
497
+
498
+ output = self.attention_operation(query, key, value, padding_info)
499
+ output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T, H*D]
500
+
501
+ output = self.inter_norm(output.float()).type_as(output)
502
+ output = self.out_proj(output)
503
+ output = self.dropout(output)
504
+
505
+ return output, v1
506
+
507
+
508
+ class FeedForward(nn.Module):
509
+ def __init__(self, config: GptBertConfig):
510
+ super().__init__()
511
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
512
+ self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
513
+ self.activation = GeGLU()
514
+ self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
515
+ self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
516
+ self.dropout = nn.Dropout(config.hidden_dropout)
517
+
518
+ def forward(self, x: torch.Tensor):
519
+ x = self.pre_norm(x.float()).type_as(x)
520
+ x = self.up_proj(x)
521
+ x = self.activation(x)
522
+ x = self.inter_norm(x.float()).type_as(x)
523
+ x = self.down_proj(x)
524
+ x = self.dropout(x)
525
+ return x
526
+
527
+
528
+ class Layer(nn.Module):
529
+ def __init__(self, config: GptBertConfig, layer_idx: int):
530
+ super().__init__()
531
+
532
+ self.attention = SelfAttention(config, layer_idx)
533
+ self.mlp = FeedForward(config)
534
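+ # Learned mixing coefficients: lambdas[0], [1], [3] and [5] interpolate between the
+ # residual stream and the token embeddings for the attention, QK, MLP and skip
+ # branches, while softplus(lambdas[2]) and softplus(lambdas[4]) scale the MLP and
+ # skip branches (see forward below).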
+ self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
535
+
536
+ def set_window_length(self, window_length: int):
537
+ self.attention.set_window_length(window_length)
538
+
539
+ def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
540
+ attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
541
+ qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
542
+ mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
543
+
544
+ attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
545
+ mlp_layer = mlp_layer + attention_output
546
+ hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
547
+ output = hidden_layer + attention_output + self.mlp(mlp_layer)
548
+
549
+ return output, v1
550
+
551
+
552
+ class Encoder(nn.Module):
553
+ def __init__(self, config: GptBertConfig):
554
+ super().__init__()
555
+ self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
556
+ self.local_global_ratio = config.local_global_ratio
557
+
558
+ def set_window_length(self, config: GptBertConfig):
559
+ for i, layer in enumerate(self.layers):
560
+ if (i + 1) % self.local_global_ratio == 0:
561
+ layer.set_window_length(config.global_window_length)
562
+ else:
563
+ layer.set_window_length(config.local_window_length)
564
+
565
+ def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
566
+ hidden_layers = [hidden_layer] if output_hidden_states else None
567
+ v1 = None
568
+ embeddings = hidden_layer
569
+
570
+ for layer in self.layers:
571
+ if checkpoint_activations:
572
+ hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
573
+ else:
574
+ hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
575
+
576
+ if output_hidden_states:
577
+ hidden_layers.append(hidden_layer)
578
+
579
+ return hidden_layer, hidden_layers
580
+
581
+
582
+ #
583
+ # HuggingFace wrappers
584
+ #
585
+
586
+ class GptBertPreTrainedModel(PreTrainedModel):
587
+ config_class = GptBertConfig
588
+ supports_gradient_checkpointing = True
589
+ _supports_flash_attn_2 = True
590
+ _supports_sdpa = True
591
+ _supports_flex_attn = False
592
+
593
+ def _init_weights(self, module):
594
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
595
+
596
+ if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
597
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
598
+ if module.bias is not None:
599
+ module.bias.data.zero_()
600
+ elif isinstance(module, nn.Embedding):
601
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
602
+ elif isinstance(module, nn.LayerNorm):
603
+ module.bias.data.zero_()
604
+ module.weight.data.fill_(1.0)
605
+
606
+
607
+ class GptBertModel(GptBertPreTrainedModel):
608
+ def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
609
+ super().__init__(config, **kwargs)
610
+ self.config = config
611
+ self.hidden_size = config.hidden_size
612
+
613
+ self.embedding = Embedding(config)
614
+ self.encoder = Encoder(config)
615
+ self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
616
+ self.set_window_length(config)
617
+ self.gradient_checkpointing = False
618
+ self.post_init()
619
+
620
+ def set_window_length(self, config) -> None:
621
+ self.encoder.set_window_length(config)
622
+
623
+ def get_input_embeddings(self):
624
+ return self.embedding.word_embedding
625
+
626
+ def set_input_embeddings(self, value):
627
+ self.embedding.word_embedding = value
628
+
629
+ def get_contextualized_embeddings(
630
+ self,
631
+ input_ids: Optional[torch.Tensor] = None,
632
+ attention_mask: Optional[torch.Tensor] = None,
633
+ output_hidden_states: Optional[bool] = None
634
+ ):
635
+ if input_ids is not None:
636
+ input_shape = input_ids.size()
637
+ else:
638
+ raise ValueError("You have to specify input_ids")
639
+
640
+ batch_size, seq_length = input_shape
641
+ device = input_ids.device
642
+
643
+ if attention_mask is None:
644
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
645
+ else:
646
+ attention_mask = attention_mask.bool()
647
+
648
+ if flash_attn_varlen_qkvpacked_func is not None:
649
+ if len(attention_mask.size()) != 2:
650
+ raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
651
+ with torch.no_grad():
652
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
653
+ padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
654
+ else:
655
+ if len(attention_mask.size()) == 2:
656
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
657
+ elif len(attention_mask.size()) == 3:
658
+ attention_mask = attention_mask.unsqueeze(1)
659
+ padding_info = attention_mask
660
+
661
+ static_embeddings = self.embedding(input_ids)
662
+
663
+ original_dtype = static_embeddings.dtype
664
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
665
+ static_embeddings = static_embeddings.bfloat16()
666
+
667
+ last_layer, contextualized_embeddings = self.encoder(
668
+ static_embeddings,
669
+ padding_info,
670
+ output_hidden_states=output_hidden_states,
671
+ checkpoint_activations=self.gradient_checkpointing and self.training
672
+ )
673
+
674
+ last_layer = last_layer.to(original_dtype)
675
+ if output_hidden_states:
676
+ contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
677
+
678
+ # Pad output if using FlashAttention
679
+ if flash_attn_varlen_qkvpacked_func is not None:
680
+ last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
681
+ if output_hidden_states:
682
+ contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
683
+ else:
684
+ contextualized_embeddings = None
685
+
686
+ return last_layer, contextualized_embeddings
687
+
688
+ def forward(
689
+ self,
690
+ input_ids: Optional[torch.Tensor] = None,
691
+ attention_mask: Optional[torch.Tensor] = None,
692
+ output_hidden_states: Optional[bool] = None,
693
+ output_attentions: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ **kwargs
696
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
697
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
698
+
699
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
700
+
701
+ if not return_dict:
702
+ return (
703
+ sequence_output,
704
+ *([contextualized_embeddings] if output_hidden_states else [])
705
+ )
706
+
707
+ return BaseModelOutput(
708
+ last_hidden_state=sequence_output,
709
+ hidden_states=contextualized_embeddings if output_hidden_states else None
710
+ )
711
+
712
+
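A hedged usage sketch for the base model: since `GptBertModel` is a custom architecture, loading it through `transformers` requires `trust_remote_code=True`. The checkpoint path is a placeholder, not an artifact of this commit:

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Placeholder path; any directory containing this modeling code and its config works.
ckpt = "path/to/gptbert-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("Et lite eksempel.", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(out.last_hidden_state.shape)  # (1, seq_len, hidden_size)
```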
713
+ class GptBertForMaskedLM(GptBertModel):
714
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
715
+
716
+ def __init__(self, config: GptBertConfig, **kwargs):
717
+ super().__init__(config, add_mlm_layer=True, **kwargs)
718
+
719
+ def get_output_embeddings(self):
720
+ return self.classifier.emb2vocab.weight
721
+
722
+ def set_output_embeddings(self, new_embeddings):
723
+ self.classifier.emb2vocab.weight = new_embeddings
724
+
725
+ def forward(
726
+ self,
727
+ input_ids: Optional[torch.Tensor] = None,
728
+ attention_mask: Optional[torch.Tensor] = None,
729
+ output_hidden_states: Optional[bool] = None,
730
+ return_dict: Optional[bool] = None,
731
+ labels: Optional[torch.LongTensor] = None,
732
+ **kwargs
733
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
734
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
735
+
736
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
737
+ subword_prediction = self.classifier(sequence_output)
738
+ subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
739
+
740
+ masked_lm_loss = None
741
+ if labels is not None:
742
+ labels_flatten = labels[:, 1:].flatten()
743
+ subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
744
+ masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
745
+
746
+ bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
747
+ bos_logits[:, :, self.config.bos_token_id] = 1.0
748
+ subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
749
+
750
+ if not return_dict:
751
+ output = (
752
+ subword_prediction,
753
+ *([contextualized_embeddings] if output_hidden_states else [])
754
+ )
755
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
756
+
757
+ return MaskedLMOutput(
758
+ loss=masked_lm_loss,
759
+ logits=subword_prediction,
760
+ hidden_states=contextualized_embeddings if output_hidden_states else None
761
+ )
762
+
763
+
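Note the unusual logit transform in the head above: `30 * torch.sigmoid(x / 7.5)` bounds every vocabulary score to the open interval (0, 30) before cross-entropy, a soft-capping choice that limits the influence of extreme scores. A quick numerical illustration:

```python
import torch

# The MLM/CLM heads squash raw scores into (0, 30) before the loss.
x = torch.tensor([-100.0, -7.5, 0.0, 7.5, 100.0])
print(30 * torch.sigmoid(x / 7.5))  # approx. [0.00, 8.07, 15.00, 21.93, 30.00]
```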
764
+ class GptBertForCausalLM(GptBertModel):
765
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
766
+
767
+ def __init__(self, config: GptBertConfig, **kwargs):
768
+ config.is_decoder = True
769
+ super().__init__(config, add_mlm_layer=True, **kwargs)
770
+
771
+ def get_output_embeddings(self):
772
+ return self.classifier.emb2vocab.weight
773
+
774
+ def set_output_embeddings(self, new_embeddings):
775
+ self.classifier.emb2vocab.weight = new_embeddings
776
+
777
+ def get_input_embeddings(self):
778
+ return self.embedding.word_embedding
779
+
780
+ def set_input_embeddings(self, value):
781
+ self.embedding.word_embedding = value
782
+
783
+ def set_decoder(self, decoder):
784
+ self.encoder = decoder
785
+
786
+ def get_decoder(self):
787
+ return self.encoder
788
+
789
+ def can_generate(self):
790
+ return True
791
+
792
+ def forward(
793
+ self,
794
+ input_ids: torch.LongTensor = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ position_ids: Optional[torch.LongTensor] = None,
797
+ token_type_ids: Optional[torch.Tensor] = None,
798
+ past_key_values: Optional[torch.Tensor] = None,
799
+ inputs_embeds: Optional[torch.FloatTensor] = None,
800
+ labels: Optional[torch.LongTensor] = None,
801
+ use_cache: Optional[bool] = None,
802
+ cache_position: Optional[torch.LongTensor] = None,
803
+ output_attentions: Optional[bool] = None,
804
+ output_hidden_states: Optional[bool] = None,
805
+ return_dict: Optional[bool] = None
806
+ ) -> Union[Tuple, CausalLMOutput]:
807
+
808
+ assert inputs_embeds is None, "inputs_embeds is not supported for now"
809
+ assert past_key_values is None, "past_key_values is not supported for now"
810
+ assert not use_cache, "use_cache is not supported for now"
811
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
812
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
813
+ subword_prediction = self.classifier(sequence_output)
814
+ subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
815
+
816
+ causal_lm_loss = None
817
+ if labels is not None:
818
+ labels_flatten = labels[:, 1:].flatten()
819
+ subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
820
+ causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
821
+
822
+ if not return_dict:
823
+ output = (
824
+ subword_prediction,
825
+ *([contextualized_embeddings] if output_hidden_states else [])
826
+ )
827
+ return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
828
+
829
+ return CausalLMOutput(
830
+ loss=causal_lm_loss,
831
+ logits=subword_prediction,
832
+ hidden_states=contextualized_embeddings if output_hidden_states else None
833
+ )
834
+
835
+ def prepare_inputs_for_generation(
836
+ self,
837
+ input_ids: torch.Tensor,
838
+ past_key_values: Optional[torch.Tensor] = None,
839
+ attention_mask: Optional[torch.Tensor] = None,
840
+ inputs_embeds: Optional[torch.Tensor] = None,
841
+ cache_position: Optional[torch.LongTensor] = None,
842
+ position_ids: Optional[torch.LongTensor] = None,
843
+ use_cache: bool = True,
844
+ num_logits_to_keep: Optional[int] = None,
845
+ **kwargs,
846
+ ):
847
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
848
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
849
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
850
+ if past_key_values is not None:
851
+ if inputs_embeds is not None: # Exception 1
852
+ input_ids = input_ids[:, -cache_position.shape[0] :]
853
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
854
+ input_ids = input_ids[:, cache_position]
855
+
856
+ if attention_mask is not None and position_ids is None:
857
+ # create position_ids on the fly for batch generation
858
+ position_ids = attention_mask.long().cumsum(-1) - 1
859
+ position_ids.masked_fill_(attention_mask == 0, 1)
860
+ if past_key_values:
861
+ position_ids = position_ids[:, -input_ids.shape[1] :]
862
+
863
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
+ # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying
+ # strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the
+ # batch size = 1 case, `position_ids` is already contiguous but with varying stride,
+ # which retriggers a capture.
864
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
865
+
866
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
867
+ if inputs_embeds is not None and cache_position[0] == 0:
868
+ model_inputs = {"inputs_embeds": inputs_embeds}
869
+ else:
870
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
871
+
872
+ if num_logits_to_keep is not None:
873
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
874
+
875
+ model_inputs.update(
876
+ {
877
+ "position_ids": position_ids,
878
+ "cache_position": cache_position,
879
+ "past_key_values": past_key_values,
880
+ "use_cache": use_cache,
881
+ "attention_mask": attention_mask,
882
+ }
883
+ )
884
+ return model_inputs
885
+
886
+
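A hedged generation sketch for the causal head: the asserts above reject `inputs_embeds`, `past_key_values` and `use_cache`, so generation must run cache-free. The path is again a placeholder:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = "path/to/gptbert-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("Oslo er", return_tensors="pt")
ids = model.generate(**inputs, max_new_tokens=10, do_sample=False, use_cache=False)
print(tokenizer.decode(ids[0], skip_special_tokens=True))
```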
887
+ class GptBertForSequenceClassification(GptBertModel):
888
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
889
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
890
+
891
+ def __init__(self, config: GptBertConfig, **kwargs):
892
+ super().__init__(config, add_mlm_layer=False, **kwargs)
893
+
894
+ self.num_labels = config.num_labels
895
+ self.classifier = Classifier(config, self.num_labels)
896
+ self.post_init()
897
+
898
+ def forward(
899
+ self,
900
+ input_ids: Optional[torch.Tensor] = None,
901
+ attention_mask: Optional[torch.Tensor] = None,
902
+ output_hidden_states: Optional[bool] = None,
903
+ return_dict: Optional[bool] = None,
904
+ labels: Optional[torch.LongTensor] = None,
905
+ **kwargs
906
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
907
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
908
+
909
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
910
+ logits = self.classifier(sequence_output[:, 0, :])
911
+
912
+ loss = None
913
+ if labels is not None:
914
+ if self.config.problem_type is None:
915
+ if self.num_labels == 1:
916
+ self.config.problem_type = "regression"
917
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
918
+ self.config.problem_type = "single_label_classification"
919
+ else:
920
+ self.config.problem_type = "multi_label_classification"
921
+
922
+ if self.config.problem_type == "regression":
923
+ loss_fct = nn.MSELoss()
924
+ if self.num_labels == 1:
925
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
926
+ else:
927
+ loss = loss_fct(logits, labels)
928
+ elif self.config.problem_type == "single_label_classification":
929
+ loss_fct = nn.CrossEntropyLoss()
930
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
931
+ elif self.config.problem_type == "multi_label_classification":
932
+ loss_fct = nn.BCEWithLogitsLoss()
933
+ loss = loss_fct(logits, labels)
934
+
935
+ if not return_dict:
936
+ output = (
937
+ logits,
938
+ *([contextualized_embeddings] if output_hidden_states else [])
939
+ )
940
+ return ((loss,) + output) if loss is not None else output
941
+
942
+ return SequenceClassifierOutput(
943
+ loss=loss,
944
+ logits=logits,
945
+ hidden_states=contextualized_embeddings if output_hidden_states else None
946
+ )
947
+
948
+
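The loss dispatch above follows the standard Hugging Face convention for inferring `problem_type` when the config leaves it unset. As a self-contained illustration (the helper below is hypothetical, not part of the model):

```python
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Hypothetical helper mirroring the branch logic in the forward pass above.
    if num_labels == 1:
        return "regression"
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"

print(infer_problem_type(1, torch.tensor([0.7])))               # regression
print(infer_problem_type(3, torch.tensor([2])))                 # single_label_classification
print(infer_problem_type(3, torch.tensor([[0.0, 1.0, 1.0]])))   # multi_label_classification
```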
949
+ class GptBertForTokenClassification(GptBertModel):
950
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
951
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
952
+
953
+ def __init__(self, config: GptBertConfig, **kwargs):
954
+ super().__init__(config, add_mlm_layer=False, **kwargs)
955
+
956
+ self.num_labels = config.num_labels
957
+ self.classifier = Classifier(config, self.num_labels)
958
+ self.post_init()
959
+
960
+ def forward(
961
+ self,
962
+ input_ids: Optional[torch.Tensor] = None,
963
+ attention_mask: Optional[torch.Tensor] = None,
964
+ output_hidden_states: Optional[bool] = None,
965
+ return_dict: Optional[bool] = None,
966
+ labels: Optional[torch.LongTensor] = None,
967
+ **kwargs
968
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
969
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
970
+
971
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
972
+ logits = self.classifier(sequence_output)
973
+
974
+ loss = None
975
+ if labels is not None:
976
+ loss_fct = nn.CrossEntropyLoss()
977
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
978
+
979
+ if not return_dict:
980
+ output = (
981
+ logits,
982
+ *([contextualized_embeddings] if output_hidden_states else [])
984
+ )
985
+ return ((loss,) + output) if loss is not None else output
986
+
987
+ return TokenClassifierOutput(
988
+ loss=loss,
989
+ logits=logits,
990
+ hidden_states=contextualized_embeddings if output_hidden_states else None
992
+ )
993
+
994
+
995
+ class GptBertForQuestionAnswering(GptBertModel):
996
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
997
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
998
+
999
+ def __init__(self, config: GptBertConfig, **kwargs):
1000
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1001
+
1002
+ self.num_labels = config.num_labels
1003
+ self.classifier = Classifier(config, self.num_labels)
1004
+ self.post_init()
1005
+
1006
+ def forward(
1007
+ self,
1008
+ input_ids: Optional[torch.Tensor] = None,
1009
+ attention_mask: Optional[torch.Tensor] = None,
1010
+ output_hidden_states: Optional[bool] = None,
1011
+ return_dict: Optional[bool] = None,
1012
+ start_positions: Optional[torch.Tensor] = None,
1013
+ end_positions: Optional[torch.Tensor] = None,
1014
+ **kwargs
1015
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1016
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1017
+
1018
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
1019
+ logits = self.classifier(sequence_output)
1020
+
1021
+ start_logits, end_logits = logits.split(1, dim=-1)
1022
+ start_logits = start_logits.squeeze(-1).contiguous()
1023
+ end_logits = end_logits.squeeze(-1).contiguous()
1024
+
1025
+ total_loss = None
1026
+ if start_positions is not None and end_positions is not None:
1027
+ # If we are on multi-GPU, split add a dimension
1028
+ if len(start_positions.size()) > 1:
1029
+ start_positions = start_positions.squeeze(-1)
1030
+ if len(end_positions.size()) > 1:
1031
+ end_positions = end_positions.squeeze(-1)
1032
+
1033
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1034
+ ignored_index = start_logits.size(1)
1035
+ start_positions = start_positions.clamp(0, ignored_index)
1036
+ end_positions = end_positions.clamp(0, ignored_index)
1037
+
1038
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
1039
+ start_loss = loss_fct(start_logits, start_positions)
1040
+ end_loss = loss_fct(end_logits, end_positions)
1041
+ total_loss = (start_loss + end_loss) / 2
1042
+
1043
+ if not return_dict:
1044
+ output = (
1045
+ start_logits,
1046
+ end_logits,
1047
+ *([contextualized_embeddings] if output_hidden_states else [])
1048
+ )
1049
+ return ((total_loss,) + output) if total_loss is not None else output
1050
+
1051
+ return QuestionAnsweringModelOutput(
1052
+ loss=total_loss,
1053
+ start_logits=start_logits,
1054
+ end_logits=end_logits,
1055
+ hidden_states=contextualized_embeddings if output_hidden_states else None
1056
+ )
1057
+
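One detail worth noting in the QA head: answer positions that fall outside the model input are clamped to `seq_len` and then excluded from the loss via `ignore_index`. A minimal, self-contained illustration:

```python
import torch
import torch.nn as nn

seq_len = 8
start_logits = torch.randn(2, seq_len)
start_positions = torch.tensor([3, 100]).clamp(0, seq_len)  # 100 -> 8, out of range
loss_fct = nn.CrossEntropyLoss(ignore_index=seq_len)
print(loss_fct(start_logits, start_positions))  # only the first example contributes
```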
1058
+
1059
+ class GptBertForMultipleChoice(GptBertModel):
1060
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1061
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1062
+
1063
+ def __init__(self, config: GptBertConfig, **kwargs):
1064
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1065
+
1066
+ self.num_labels = getattr(config, "num_labels", 2)
1067
+ self.classifier = Classifier(config, 1)  # one score per choice; reshaped to (batch_size, num_choices) below
1068
+ self.post_init()
1069
+
1070
+ def forward(
1071
+ self,
1072
+ input_ids: Optional[torch.Tensor] = None,
1073
+ attention_mask: Optional[torch.Tensor] = None,
1074
+ labels: Optional[torch.Tensor] = None,
1075
+ output_hidden_states: Optional[bool] = None,
1076
+ return_dict: Optional[bool] = None,
1077
+ **kwargs
1078
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1079
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1080
+ num_choices = input_ids.shape[1]
1081
+
1082
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1083
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1084
+
1085
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
1086
+ logits = self.classifier(sequence_output[:, 0, :])  # pool the first token of each flattened choice, as in the sequence classification head
1087
+ reshaped_logits = logits.view(-1, num_choices)
1088
+
1089
+ loss = None
1090
+ if labels is not None:
1091
+ loss_fct = nn.CrossEntropyLoss()
1092
+ loss = loss_fct(reshaped_logits, labels)
1093
+
1094
+ if not return_dict:
1095
+ output = (
1096
+ reshaped_logits,
1097
+ *([contextualized_embeddings] if output_hidden_states else [])
1098
+ )
1099
+ return ((loss,) + output) if loss is not None else output
1100
+
1101
+ return MultipleChoiceModelOutput(
1102
+ loss=loss,
1103
+ logits=reshaped_logits,
1104
+ hidden_states=contextualized_embeddings if output_hidden_states else None
1105
+ )
last-checkpoint/modules.json ADDED
@@ -0,0 +1,14 @@
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.sparse_encoder.models.MLMTransformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_SpladePooling",
12
+ "type": "sentence_transformers.sparse_encoder.models.SpladePooling"
13
+ }
14
+ ]
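The two modules listed above (an `MLMTransformer` followed by `SpladePooling`) are what Sentence Transformers reassembles when the checkpoint directory is loaded as a `SparseEncoder`. A hedged sketch with a placeholder path:

```python
from sentence_transformers import SparseEncoder

# Placeholder path to a directory containing this checkpoint.
model = SparseEncoder("path/to/last-checkpoint", trust_remote_code=True)
emb = model.encode(["Hva er SPLADE?"])
print(emb.shape)  # (1, 51200): one weight per vocabulary entry, ideally mostly zero after training
```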
last-checkpoint/optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ee91d450dde2beb96e8c2398912baa13f858a3b4f7cee07b28a9b96f3e588ef
3
+ size 1457369077
last-checkpoint/rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cdc784e3b91bc23bce54961fdaef58e6442cd03f625edf44e230178fd37f8fa
3
+ size 14917
last-checkpoint/rng_state_1.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a36844f32afb06c561965a6f6eb81809058336e154b9d7e2fb6b83900a7ad0fa
3
+ size 14917
last-checkpoint/scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d959e23fadc9c5a5f14d7b9c3a56d1fd374b1bcf4b39d4b142f83de164ff2685
3
+ size 1465
last-checkpoint/sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "max_seq_length": null,
3
+ "do_lower_case": false
4
+ }
last-checkpoint/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
last-checkpoint/tokenizer.json ADDED
The diff for this file is too large to render.
last-checkpoint/tokenizer_config.json ADDED
@@ -0,0 +1,143 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<unk>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<pad>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<special_0>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "6": {
52
+ "content": "<special_1>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "7": {
60
+ "content": "<special_2>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "8": {
68
+ "content": "<special_3>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "9": {
76
+ "content": "<special_4>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "10": {
84
+ "content": "<special_5>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "11": {
92
+ "content": "<special_6>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "12": {
100
+ "content": "<special_7>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "13": {
108
+ "content": "<special_8>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "14": {
116
+ "content": "<special_9>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "15": {
124
+ "content": "<special_10>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ }
131
+ },
132
+ "bos_token": "<s>",
133
+ "clean_up_tokenization_spaces": false,
134
+ "cls_token": "<s>",
135
+ "eos_token": "</s>",
136
+ "extra_special_tokens": {},
137
+ "mask_token": "<mask>",
138
+ "model_max_length": 4096,
139
+ "pad_token": "<pad>",
140
+ "sep_token": "</s>",
141
+ "tokenizer_class": "PreTrainedTokenizerFast",
142
+ "unk_token": "<unk>"
143
+ }
last-checkpoint/trainer_state.json ADDED
@@ -0,0 +1,185 @@
1
+ {
2
+ "best_global_step": 500,
3
+ "best_metric": 0.027178706104522946,
4
+ "best_model_checkpoint": "models/splade-norbert4-base-retrieval-only/checkpoint-500",
5
+ "epoch": 0.04797083373309028,
6
+ "eval_steps": 500,
7
+ "global_step": 500,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "base_loss": 37895.6328,
14
+ "document_regularizer_loss": 0.0598,
15
+ "epoch": 0.004797083373309028,
16
+ "grad_norm": 259787.75,
17
+ "learning_rate": 9.395973154362417e-07,
18
+ "loss": 37895.69,
19
+ "query_regularizer_loss": 0.001,
20
+ "step": 50
21
+ },
22
+ {
23
+ "base_loss": 10001.6025,
24
+ "document_regularizer_loss": 0.4482,
25
+ "epoch": 0.009594166746618057,
26
+ "grad_norm": 505128.8125,
27
+ "learning_rate": 1.8983700862895495e-06,
28
+ "loss": 10002.0562,
29
+ "query_regularizer_loss": 0.0037,
30
+ "step": 100
31
+ },
32
+ {
33
+ "base_loss": 3804.3779,
34
+ "document_regularizer_loss": 1.0922,
35
+ "epoch": 0.014391250119927085,
36
+ "grad_norm": 53786.4296875,
37
+ "learning_rate": 2.8571428571428573e-06,
38
+ "loss": 3805.4731,
39
+ "query_regularizer_loss": 0.0023,
40
+ "step": 150
41
+ },
42
+ {
43
+ "base_loss": 921.3414,
44
+ "document_regularizer_loss": 1.7523,
45
+ "epoch": 0.019188333493236114,
46
+ "grad_norm": 48469.69140625,
47
+ "learning_rate": 3.815915627996165e-06,
48
+ "loss": 923.0944,
49
+ "query_regularizer_loss": 0.0007,
50
+ "step": 200
51
+ },
52
+ {
53
+ "base_loss": 512.5709,
54
+ "document_regularizer_loss": 2.2081,
55
+ "epoch": 0.02398541686654514,
56
+ "grad_norm": 9211.7822265625,
57
+ "learning_rate": 4.774688398849473e-06,
58
+ "loss": 514.7795,
59
+ "query_regularizer_loss": 0.0005,
60
+ "step": 250
61
+ },
62
+ {
63
+ "base_loss": 282.497,
64
+ "document_regularizer_loss": 2.0475,
65
+ "epoch": 0.02878250023985417,
66
+ "grad_norm": 430.40460205078125,
67
+ "learning_rate": 5.733461169702781e-06,
68
+ "loss": 284.5449,
69
+ "query_regularizer_loss": 0.0003,
70
+ "step": 300
71
+ },
72
+ {
73
+ "base_loss": 88.6157,
74
+ "document_regularizer_loss": 1.4521,
75
+ "epoch": 0.0335795836131632,
76
+ "grad_norm": 660.3614501953125,
77
+ "learning_rate": 6.692233940556089e-06,
78
+ "loss": 90.0678,
79
+ "query_regularizer_loss": 0.0001,
80
+ "step": 350
81
+ },
82
+ {
83
+ "base_loss": 30.6636,
84
+ "document_regularizer_loss": 0.1846,
85
+ "epoch": 0.03837666698647223,
86
+ "grad_norm": 3.8934402465820312,
87
+ "learning_rate": 7.651006711409396e-06,
88
+ "loss": 30.8482,
89
+ "query_regularizer_loss": 0.0,
90
+ "step": 400
91
+ },
92
+ {
93
+ "base_loss": 2.5066,
94
+ "document_regularizer_loss": 0.0005,
95
+ "epoch": 0.04317375035978125,
96
+ "grad_norm": 26.094982147216797,
97
+ "learning_rate": 8.609779482262704e-06,
98
+ "loss": 2.5071,
99
+ "query_regularizer_loss": 0.0,
100
+ "step": 450
101
+ },
102
+ {
103
+ "base_loss": 1.3518,
104
+ "document_regularizer_loss": 0.0007,
105
+ "epoch": 0.04797083373309028,
106
+ "grad_norm": 20.548200607299805,
107
+ "learning_rate": 9.568552253116012e-06,
108
+ "loss": 1.3525,
109
+ "query_regularizer_loss": 0.0,
110
+ "step": 500
111
+ },
112
+ {
113
+ "epoch": 0.04797083373309028,
114
+ "eval_NanoBEIR_mean_avg_flops": 51200.0,
115
+ "eval_NanoBEIR_mean_corpus_active_dims": 51200.0,
116
+ "eval_NanoBEIR_mean_corpus_sparsity_ratio": 0.0,
117
+ "eval_NanoBEIR_mean_dot_accuracy@1": 0.02,
118
+ "eval_NanoBEIR_mean_dot_accuracy@10": 0.12,
119
+ "eval_NanoBEIR_mean_dot_accuracy@3": 0.08,
120
+ "eval_NanoBEIR_mean_dot_accuracy@5": 0.08,
121
+ "eval_NanoBEIR_mean_dot_map@100": 0.006747512755501429,
122
+ "eval_NanoBEIR_mean_dot_mrr@10": 0.05088888888888889,
123
+ "eval_NanoBEIR_mean_dot_ndcg@10": 0.027178706104522946,
124
+ "eval_NanoBEIR_mean_dot_precision@1": 0.02,
125
+ "eval_NanoBEIR_mean_dot_precision@10": 0.026000000000000006,
126
+ "eval_NanoBEIR_mean_dot_precision@3": 0.03333333333333333,
127
+ "eval_NanoBEIR_mean_dot_precision@5": 0.032,
128
+ "eval_NanoBEIR_mean_dot_recall@1": 7.905138339920947e-05,
129
+ "eval_NanoBEIR_mean_dot_recall@10": 0.006349071275176555,
130
+ "eval_NanoBEIR_mean_dot_recall@3": 0.003312410422185988,
131
+ "eval_NanoBEIR_mean_dot_recall@5": 0.004545769460972766,
132
+ "eval_NanoBEIR_mean_query_active_dims": 51200.0,
133
+ "eval_NanoBEIR_mean_query_sparsity_ratio": 0.0,
134
+ "eval_NanoNFCorpus_avg_flops": 51200.0,
135
+ "eval_NanoNFCorpus_corpus_active_dims": 51200.0,
136
+ "eval_NanoNFCorpus_corpus_sparsity_ratio": 0.0,
137
+ "eval_NanoNFCorpus_dot_accuracy@1": 0.02,
138
+ "eval_NanoNFCorpus_dot_accuracy@10": 0.12,
139
+ "eval_NanoNFCorpus_dot_accuracy@3": 0.08,
140
+ "eval_NanoNFCorpus_dot_accuracy@5": 0.08,
141
+ "eval_NanoNFCorpus_dot_map@100": 0.006747512755501429,
142
+ "eval_NanoNFCorpus_dot_mrr@10": 0.05088888888888889,
143
+ "eval_NanoNFCorpus_dot_ndcg@10": 0.027178706104522946,
144
+ "eval_NanoNFCorpus_dot_precision@1": 0.02,
145
+ "eval_NanoNFCorpus_dot_precision@10": 0.026000000000000006,
146
+ "eval_NanoNFCorpus_dot_precision@3": 0.03333333333333333,
147
+ "eval_NanoNFCorpus_dot_precision@5": 0.032,
148
+ "eval_NanoNFCorpus_dot_recall@1": 7.905138339920947e-05,
149
+ "eval_NanoNFCorpus_dot_recall@10": 0.006349071275176555,
150
+ "eval_NanoNFCorpus_dot_recall@3": 0.003312410422185988,
151
+ "eval_NanoNFCorpus_dot_recall@5": 0.004545769460972766,
152
+ "eval_NanoNFCorpus_query_active_dims": 51200.0,
153
+ "eval_NanoNFCorpus_query_sparsity_ratio": 0.0,
154
+ "eval_base_loss": 2.2657,
155
+ "eval_document_regularizer_loss": 0.0006,
156
+ "eval_loss": 2.2663323879241943,
157
+ "eval_query_regularizer_loss": 0.0,
158
+ "eval_runtime": 364.216,
159
+ "eval_samples_per_second": 39.696,
160
+ "eval_steps_per_second": 0.621,
161
+ "step": 500
162
+ }
163
+ ],
164
+ "logging_steps": 50,
165
+ "max_steps": 10423,
166
+ "num_input_tokens_seen": 0,
167
+ "num_train_epochs": 1,
168
+ "save_steps": 500,
169
+ "stateful_callbacks": {
170
+ "TrainerControl": {
171
+ "args": {
172
+ "should_epoch_stop": false,
173
+ "should_evaluate": false,
174
+ "should_log": false,
175
+ "should_save": true,
176
+ "should_training_stop": false
177
+ },
178
+ "attributes": {}
179
+ }
180
+ },
181
+ "total_flos": 0.0,
182
+ "train_batch_size": 16,
183
+ "trial_name": null,
184
+ "trial_params": null
185
+ }
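One hedged reading of the eval block above: `query_active_dims` and `corpus_active_dims` equal the full 51,200-dimensional vocabulary and both sparsity ratios are 0.0, so at step 500 the FLOPS regularizer has not yet pushed any dimension to zero. The ratio is presumably computed as one minus the active fraction:

```python
# Assumed definition, consistent with the reported values above.
vocab_size = 51200
active_dims = 51200.0
sparsity_ratio = 1.0 - active_dims / vocab_size
print(sparsity_ratio)  # 0.0 -> every dimension still active at this checkpoint
```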
last-checkpoint/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80a238e2a37bc1da8ffa7f6ac3192a8f23e297afa1bc8b0a28a0d40a8e101359
3
+ size 6353