Training in progress, step 500, checkpoint
Browse files- last-checkpoint/1_SpladePooling/config.json +5 -0
- last-checkpoint/README.md +554 -0
- last-checkpoint/config.json +43 -0
- last-checkpoint/config_sentence_transformers.json +14 -0
- last-checkpoint/configuration_gptbert.py +34 -0
- last-checkpoint/model.safetensors +3 -0
- last-checkpoint/modeling_gptbert.py +1105 -0
- last-checkpoint/modules.json +14 -0
- last-checkpoint/optimizer.pt +3 -0
- last-checkpoint/rng_state_0.pth +3 -0
- last-checkpoint/rng_state_1.pth +3 -0
- last-checkpoint/scheduler.pt +3 -0
- last-checkpoint/sentence_bert_config.json +4 -0
- last-checkpoint/special_tokens_map.json +51 -0
- last-checkpoint/tokenizer.json +0 -0
- last-checkpoint/tokenizer_config.json +143 -0
- last-checkpoint/trainer_state.json +185 -0
- last-checkpoint/training_args.bin +3 -0
last-checkpoint/1_SpladePooling/config.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pooling_strategy": "max",
|
| 3 |
+
"activation_function": "relu",
|
| 4 |
+
"word_embedding_dimension": 51200
|
| 5 |
+
}
|
last-checkpoint/README.md
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- 'no'
|
| 4 |
+
- da
|
| 5 |
+
- sv
|
| 6 |
+
license: mit
|
| 7 |
+
tags:
|
| 8 |
+
- sentence-transformers
|
| 9 |
+
- sparse-encoder
|
| 10 |
+
- sparse
|
| 11 |
+
- splade
|
| 12 |
+
- generated_from_trainer
|
| 13 |
+
- dataset_size:333547
|
| 14 |
+
- loss:SpladeLoss
|
| 15 |
+
- loss:SparseMultipleNegativesRankingLoss
|
| 16 |
+
- loss:FlopsLoss
|
| 17 |
+
base_model: ltg/norbert4-base
|
| 18 |
+
widget:
|
| 19 |
+
- text: "\n \nJeg begyndte at forstå, hvilke vældige kræfter min lille historie\
|
| 20 |
+
\ havde sluppet løs.\n \n "
|
| 21 |
+
- text: "\n \nIfølge Empires job-bibel skal en direktør-assistent ikke dække bord.\n\
|
| 22 |
+
\ \n "
|
| 23 |
+
- text: "\n \nDet kan du da ikke gøre!\n "
|
| 24 |
+
- text: "\n \nJeg må købe flere sherbet fountains.\n \n "
|
| 25 |
+
- text: Søren Kierkegaard, den danske filosof og teolog, var dybt fascineret af begrebet
|
| 26 |
+
tro. I sine mange skrifter udforskede han troens natur, dens paradokser og dens
|
| 27 |
+
betydning for det individuelle liv. Han anså troen for at være et ”spring i det
|
| 28 |
+
forlommede”, en akt af vilje der overstiger fornuften. I værker som ”Frygt og
|
| 29 |
+
Trekken” og ”Sygdommen til Døden” analyserede han troens relation til angst, desperation
|
| 30 |
+
og den eksistentielle krise. Kierkegaards tanker om tro har haft stor indflydelse
|
| 31 |
+
på kristen teologi og eksistentialisme.
|
| 32 |
+
pipeline_tag: feature-extraction
|
| 33 |
+
library_name: sentence-transformers
|
| 34 |
+
metrics:
|
| 35 |
+
- dot_accuracy@1
|
| 36 |
+
- dot_accuracy@3
|
| 37 |
+
- dot_accuracy@5
|
| 38 |
+
- dot_accuracy@10
|
| 39 |
+
- dot_precision@1
|
| 40 |
+
- dot_precision@3
|
| 41 |
+
- dot_precision@5
|
| 42 |
+
- dot_precision@10
|
| 43 |
+
- dot_recall@1
|
| 44 |
+
- dot_recall@3
|
| 45 |
+
- dot_recall@5
|
| 46 |
+
- dot_recall@10
|
| 47 |
+
- dot_ndcg@10
|
| 48 |
+
- dot_mrr@10
|
| 49 |
+
- dot_map@100
|
| 50 |
+
- query_active_dims
|
| 51 |
+
- query_sparsity_ratio
|
| 52 |
+
- corpus_active_dims
|
| 53 |
+
- corpus_sparsity_ratio
|
| 54 |
+
- avg_flops
|
| 55 |
+
model-index:
|
| 56 |
+
- name: Regular SPLADE NorBERT4-base — Retrieval-Only Training
|
| 57 |
+
results:
|
| 58 |
+
- task:
|
| 59 |
+
type: sparse-information-retrieval
|
| 60 |
+
name: Sparse Information Retrieval
|
| 61 |
+
dataset:
|
| 62 |
+
name: NanoNFCorpus
|
| 63 |
+
type: NanoNFCorpus
|
| 64 |
+
metrics:
|
| 65 |
+
- type: dot_accuracy@1
|
| 66 |
+
value: 0.02
|
| 67 |
+
name: Dot Accuracy@1
|
| 68 |
+
- type: dot_accuracy@3
|
| 69 |
+
value: 0.08
|
| 70 |
+
name: Dot Accuracy@3
|
| 71 |
+
- type: dot_accuracy@5
|
| 72 |
+
value: 0.08
|
| 73 |
+
name: Dot Accuracy@5
|
| 74 |
+
- type: dot_accuracy@10
|
| 75 |
+
value: 0.12
|
| 76 |
+
name: Dot Accuracy@10
|
| 77 |
+
- type: dot_precision@1
|
| 78 |
+
value: 0.02
|
| 79 |
+
name: Dot Precision@1
|
| 80 |
+
- type: dot_precision@3
|
| 81 |
+
value: 0.03333333333333333
|
| 82 |
+
name: Dot Precision@3
|
| 83 |
+
- type: dot_precision@5
|
| 84 |
+
value: 0.032
|
| 85 |
+
name: Dot Precision@5
|
| 86 |
+
- type: dot_precision@10
|
| 87 |
+
value: 0.026000000000000006
|
| 88 |
+
name: Dot Precision@10
|
| 89 |
+
- type: dot_recall@1
|
| 90 |
+
value: 7.905138339920947e-05
|
| 91 |
+
name: Dot Recall@1
|
| 92 |
+
- type: dot_recall@3
|
| 93 |
+
value: 0.003312410422185988
|
| 94 |
+
name: Dot Recall@3
|
| 95 |
+
- type: dot_recall@5
|
| 96 |
+
value: 0.004545769460972766
|
| 97 |
+
name: Dot Recall@5
|
| 98 |
+
- type: dot_recall@10
|
| 99 |
+
value: 0.006349071275176555
|
| 100 |
+
name: Dot Recall@10
|
| 101 |
+
- type: dot_ndcg@10
|
| 102 |
+
value: 0.027178706104522946
|
| 103 |
+
name: Dot Ndcg@10
|
| 104 |
+
- type: dot_mrr@10
|
| 105 |
+
value: 0.05088888888888889
|
| 106 |
+
name: Dot Mrr@10
|
| 107 |
+
- type: dot_map@100
|
| 108 |
+
value: 0.006747512755501429
|
| 109 |
+
name: Dot Map@100
|
| 110 |
+
- type: query_active_dims
|
| 111 |
+
value: 51200.0
|
| 112 |
+
name: Query Active Dims
|
| 113 |
+
- type: query_sparsity_ratio
|
| 114 |
+
value: 0.0
|
| 115 |
+
name: Query Sparsity Ratio
|
| 116 |
+
- type: corpus_active_dims
|
| 117 |
+
value: 51200.0
|
| 118 |
+
name: Corpus Active Dims
|
| 119 |
+
- type: corpus_sparsity_ratio
|
| 120 |
+
value: 0.0
|
| 121 |
+
name: Corpus Sparsity Ratio
|
| 122 |
+
- type: avg_flops
|
| 123 |
+
value: 51200.0
|
| 124 |
+
name: Avg Flops
|
| 125 |
+
---
|
| 126 |
+
|
| 127 |
+
# Regular SPLADE NorBERT4-base — Retrieval-Only Training
|
| 128 |
+
|
| 129 |
+
This is a [SPLADE Sparse Encoder](https://www.sbert.net/docs/sparse_encoder/usage/usage.html) model finetuned from [ltg/norbert4-base](https://huggingface.co/ltg/norbert4-base) using the [sentence-transformers](https://www.SBERT.net) library. It maps sentences & paragraphs to a 51200-dimensional sparse vector space and can be used for semantic search and sparse retrieval.
|
| 130 |
+
## Model Details
|
| 131 |
+
|
| 132 |
+
### Model Description
|
| 133 |
+
- **Model Type:** SPLADE Sparse Encoder
|
| 134 |
+
- **Base model:** [ltg/norbert4-base](https://huggingface.co/ltg/norbert4-base) <!-- at revision f04e0e824de9ff9a08767727dc8891d38fddd032 -->
|
| 135 |
+
- **Maximum Sequence Length:** None tokens
|
| 136 |
+
- **Output Dimensionality:** 51200 dimensions
|
| 137 |
+
- **Similarity Function:** Dot Product
|
| 138 |
+
<!-- - **Training Dataset:** Unknown -->
|
| 139 |
+
- **Languages:** no, da, sv
|
| 140 |
+
- **License:** mit
|
| 141 |
+
|
| 142 |
+
### Model Sources
|
| 143 |
+
|
| 144 |
+
- **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
|
| 145 |
+
- **Documentation:** [Sparse Encoder Documentation](https://www.sbert.net/docs/sparse_encoder/usage/usage.html)
|
| 146 |
+
- **Repository:** [Sentence Transformers on GitHub](https://github.com/huggingface/sentence-transformers)
|
| 147 |
+
- **Hugging Face:** [Sparse Encoders on Hugging Face](https://huggingface.co/models?library=sentence-transformers&other=sparse-encoder)
|
| 148 |
+
|
| 149 |
+
### Full Model Architecture
|
| 150 |
+
|
| 151 |
+
```
|
| 152 |
+
SparseEncoder(
|
| 153 |
+
(0): MLMTransformer({'max_seq_length': None, 'do_lower_case': False, 'architecture': 'GptBertForMaskedLM'})
|
| 154 |
+
(1): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': 51200})
|
| 155 |
+
)
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## Usage
|
| 159 |
+
|
| 160 |
+
### Direct Usage (Sentence Transformers)
|
| 161 |
+
|
| 162 |
+
First install the Sentence Transformers library:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
pip install -U sentence-transformers
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
Then you can load this model and run inference.
|
| 169 |
+
```python
|
| 170 |
+
from sentence_transformers import SparseEncoder
|
| 171 |
+
|
| 172 |
+
# Download from the 🤗 Hub
|
| 173 |
+
model = SparseEncoder("thivy/norbert4-base-splade-retrieval")
|
| 174 |
+
# Run inference
|
| 175 |
+
sentences = [
|
| 176 |
+
'\n \nJeg vil ikke ha noen innvendinger.\n \n ',
|
| 177 |
+
'\n \nJeg ville ikke have nogen indvendinger.\n \n ',
|
| 178 |
+
'Søren Kierkegaard, den danske filosof og teolog, var dybt fascineret af begrebet tro. I sine mange skrifter udforskede han troens natur, dens paradokser og dens betydning for det individuelle liv. Han anså troen for at være et ”spring i det forlommede”, en akt af vilje der overstiger fornuften. I værker som ”Frygt og Trekken” og ”Sygdommen til Døden” analyserede han troens relation til angst, desperation og den eksistentielle krise. Kierkegaards tanker om tro har haft stor indflydelse på kristen teologi og eksistentialisme.',
|
| 179 |
+
]
|
| 180 |
+
embeddings = model.encode(sentences)
|
| 181 |
+
print(embeddings.shape)
|
| 182 |
+
# [3, 51200]
|
| 183 |
+
|
| 184 |
+
# Get the similarity scores for the embeddings
|
| 185 |
+
similarities = model.similarity(embeddings, embeddings)
|
| 186 |
+
print(similarities)
|
| 187 |
+
# tensor([[ 8.0400, 6.6640, 6.9193],
|
| 188 |
+
# [ 6.6640, 10.4033, 9.1223],
|
| 189 |
+
# [ 6.9193, 9.1223, 20.8932]])
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
<!--
|
| 193 |
+
### Direct Usage (Transformers)
|
| 194 |
+
|
| 195 |
+
<details><summary>Click to see the direct usage in Transformers</summary>
|
| 196 |
+
|
| 197 |
+
</details>
|
| 198 |
+
-->
|
| 199 |
+
|
| 200 |
+
<!--
|
| 201 |
+
### Downstream Usage (Sentence Transformers)
|
| 202 |
+
|
| 203 |
+
You can finetune this model on your own dataset.
|
| 204 |
+
|
| 205 |
+
<details><summary>Click to expand</summary>
|
| 206 |
+
|
| 207 |
+
</details>
|
| 208 |
+
-->
|
| 209 |
+
|
| 210 |
+
<!--
|
| 211 |
+
### Out-of-Scope Use
|
| 212 |
+
|
| 213 |
+
*List how the model may foreseeably be misused and address what users ought not to do with the model.*
|
| 214 |
+
-->
|
| 215 |
+
|
| 216 |
+
## Evaluation
|
| 217 |
+
|
| 218 |
+
### Metrics
|
| 219 |
+
|
| 220 |
+
#### Sparse Information Retrieval
|
| 221 |
+
|
| 222 |
+
* Dataset: `NanoNFCorpus`
|
| 223 |
+
* Evaluated with [<code>SparseInformationRetrievalEvaluator</code>](https://sbert.net/docs/package_reference/sparse_encoder/evaluation.html#sentence_transformers.sparse_encoder.evaluation.SparseInformationRetrievalEvaluator)
|
| 224 |
+
|
| 225 |
+
| Metric | Value |
|
| 226 |
+
|:----------------------|:-----------|
|
| 227 |
+
| dot_accuracy@1 | 0.02 |
|
| 228 |
+
| dot_accuracy@3 | 0.08 |
|
| 229 |
+
| dot_accuracy@5 | 0.08 |
|
| 230 |
+
| dot_accuracy@10 | 0.12 |
|
| 231 |
+
| dot_precision@1 | 0.02 |
|
| 232 |
+
| dot_precision@3 | 0.0333 |
|
| 233 |
+
| dot_precision@5 | 0.032 |
|
| 234 |
+
| dot_precision@10 | 0.026 |
|
| 235 |
+
| dot_recall@1 | 0.0001 |
|
| 236 |
+
| dot_recall@3 | 0.0033 |
|
| 237 |
+
| dot_recall@5 | 0.0045 |
|
| 238 |
+
| dot_recall@10 | 0.0063 |
|
| 239 |
+
| **dot_ndcg@10** | **0.0272** |
|
| 240 |
+
| dot_mrr@10 | 0.0509 |
|
| 241 |
+
| dot_map@100 | 0.0067 |
|
| 242 |
+
| query_active_dims | 51200.0 |
|
| 243 |
+
| query_sparsity_ratio | 0.0 |
|
| 244 |
+
| corpus_active_dims | 51200.0 |
|
| 245 |
+
| corpus_sparsity_ratio | 0.0 |
|
| 246 |
+
| avg_flops | 51200.0 |
|
| 247 |
+
|
| 248 |
+
<!--
|
| 249 |
+
## Bias, Risks and Limitations
|
| 250 |
+
|
| 251 |
+
*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
|
| 252 |
+
-->
|
| 253 |
+
|
| 254 |
+
<!--
|
| 255 |
+
### Recommendations
|
| 256 |
+
|
| 257 |
+
*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
|
| 258 |
+
-->
|
| 259 |
+
|
| 260 |
+
## Training Details
|
| 261 |
+
|
| 262 |
+
### Training Dataset
|
| 263 |
+
|
| 264 |
+
#### Unnamed Dataset
|
| 265 |
+
|
| 266 |
+
* Size: 333,547 training samples
|
| 267 |
+
* Columns: <code>anchor</code> and <code>positive</code>
|
| 268 |
+
* Approximate statistics based on the first 1000 samples:
|
| 269 |
+
| | anchor | positive |
|
| 270 |
+
|:--------|:-----------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------|
|
| 271 |
+
| type | string | string |
|
| 272 |
+
| details | <ul><li>min: 3 tokens</li><li>mean: 22.81 tokens</li><li>max: 517 tokens</li></ul> | <ul><li>min: 1 tokens</li><li>mean: 406.29 tokens</li><li>max: 4096 tokens</li></ul> |
|
| 273 |
+
* Samples:
|
| 274 |
+
| anchor | positive |
|
| 275 |
+
|:---------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 276 |
+
| <code><br>Hun er mye eldre enn henne.<br> <br> </code> | <code><br>Hun er meget ældre end hende.<br> <br> </code> |
|
| 277 |
+
| <code><br>Hva så? <br> <br>Du lå med kona mi!<br> <br> </code> | <code><br>Men du gik i seng med min kone.<br> <br> </code> |
|
| 278 |
+
| <code>Hur aktiverar jag en indeksfond?</code> | <code>Att investera i indexfonder är ett populärt sätt att exponera sig mot aktiemarknaden. Det är ett passivt investeringsalternativ där portföljen följer en specifik index, till exempel OMX Stockholm 30.<br><br>För att aktivera en indexfond behöver du ett depåkonto hos en bank eller en investmentsmäklare. Innan du påbörjar processen bör du noggrant undersöka och jämföra olika fonder för att hitta den som bäst passar dina investeringsmål och risktolerans.<br><br>När du väl har valt en fond kan du vanligtvis aktivera den online via bankens eller mäklarens plattform. Du behöver ange hur mycket du vill investera och godkänna villkoren. Därefter kommer fonden att köpas och lagts till i ditt depåkonto.<br><br>Det är viktigt att ha en långsiktig investeringshorisont när du investerar i indexfonder. Marknaderna fluktuerar i värde på kort sikt, men över tid har indexfonder historiskt sett genererat goda avkastningar.</code> |
|
| 279 |
+
* Loss: [<code>SpladeLoss</code>](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:
|
| 280 |
+
```json
|
| 281 |
+
{
|
| 282 |
+
"loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct='dot_score', gather_across_devices=False)",
|
| 283 |
+
"document_regularizer_weight": 0.003,
|
| 284 |
+
"query_regularizer_weight": 0.0001
|
| 285 |
+
}
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
### Evaluation Dataset
|
| 289 |
+
|
| 290 |
+
#### Unnamed Dataset
|
| 291 |
+
|
| 292 |
+
* Size: 14,458 evaluation samples
|
| 293 |
+
* Columns: <code>anchor</code> and <code>positive</code>
|
| 294 |
+
* Approximate statistics based on the first 1000 samples:
|
| 295 |
+
| | anchor | positive |
|
| 296 |
+
|:--------|:----------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------|
|
| 297 |
+
| type | string | string |
|
| 298 |
+
| details | <ul><li>min: 3 tokens</li><li>mean: 16.03 tokens</li><li>max: 86 tokens</li></ul> | <ul><li>min: 7 tokens</li><li>mean: 134.75 tokens</li><li>max: 4096 tokens</li></ul> |
|
| 299 |
+
* Samples:
|
| 300 |
+
| anchor | positive |
|
| 301 |
+
|:--------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------|
|
| 302 |
+
| <code><br> <br>Hva er det for organisasjon som skal ha årsmøte her?<br> <br> </code> | <code><br> <br>Hvilken organisation skal holde kongres her?<br> <br> </code> |
|
| 303 |
+
| <code><br>Livet ditt er jo ikke så verst.<br> <br> </code> | <code><br>Dit liv er ikke så slemt.<br> <br> </code> |
|
| 304 |
+
| <code><br> <br>Men du må ta deg av dem for meg, okay?<br> <br> </code> | <code><br> <br>Men du må tage dig af dem for mig, okay?<br> <br> </code> |
|
| 305 |
+
* Loss: [<code>SpladeLoss</code>](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:
|
| 306 |
+
```json
|
| 307 |
+
{
|
| 308 |
+
"loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct='dot_score', gather_across_devices=False)",
|
| 309 |
+
"document_regularizer_weight": 0.003,
|
| 310 |
+
"query_regularizer_weight": 0.0001
|
| 311 |
+
}
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
### Training Hyperparameters
|
| 315 |
+
#### Non-Default Hyperparameters
|
| 316 |
+
|
| 317 |
+
- `eval_strategy`: steps
|
| 318 |
+
- `per_device_train_batch_size`: 16
|
| 319 |
+
- `per_device_eval_batch_size`: 32
|
| 320 |
+
- `learning_rate`: 2e-05
|
| 321 |
+
- `weight_decay`: 0.01
|
| 322 |
+
- `num_train_epochs`: 1
|
| 323 |
+
- `warmup_ratio`: 0.1
|
| 324 |
+
- `bf16`: True
|
| 325 |
+
- `dataloader_num_workers`: 2
|
| 326 |
+
- `dataloader_prefetch_factor`: 2
|
| 327 |
+
- `load_best_model_at_end`: True
|
| 328 |
+
- `ddp_find_unused_parameters`: True
|
| 329 |
+
- `push_to_hub`: True
|
| 330 |
+
- `hub_model_id`: thivy/norbert4-base-splade-retrieval
|
| 331 |
+
- `hub_strategy`: checkpoint
|
| 332 |
+
- `hub_private_repo`: False
|
| 333 |
+
- `gradient_checkpointing`: True
|
| 334 |
+
- `gradient_checkpointing_kwargs`: {'use_reentrant': False}
|
| 335 |
+
- `multi_dataset_batch_sampler`: round_robin
|
| 336 |
+
|
| 337 |
+
#### All Hyperparameters
|
| 338 |
+
<details><summary>Click to expand</summary>
|
| 339 |
+
|
| 340 |
+
- `overwrite_output_dir`: False
|
| 341 |
+
- `do_predict`: False
|
| 342 |
+
- `eval_strategy`: steps
|
| 343 |
+
- `prediction_loss_only`: True
|
| 344 |
+
- `per_device_train_batch_size`: 16
|
| 345 |
+
- `per_device_eval_batch_size`: 32
|
| 346 |
+
- `per_gpu_train_batch_size`: None
|
| 347 |
+
- `per_gpu_eval_batch_size`: None
|
| 348 |
+
- `gradient_accumulation_steps`: 1
|
| 349 |
+
- `eval_accumulation_steps`: None
|
| 350 |
+
- `torch_empty_cache_steps`: None
|
| 351 |
+
- `learning_rate`: 2e-05
|
| 352 |
+
- `weight_decay`: 0.01
|
| 353 |
+
- `adam_beta1`: 0.9
|
| 354 |
+
- `adam_beta2`: 0.999
|
| 355 |
+
- `adam_epsilon`: 1e-08
|
| 356 |
+
- `max_grad_norm`: 1.0
|
| 357 |
+
- `num_train_epochs`: 1
|
| 358 |
+
- `max_steps`: -1
|
| 359 |
+
- `lr_scheduler_type`: linear
|
| 360 |
+
- `lr_scheduler_kwargs`: {}
|
| 361 |
+
- `warmup_ratio`: 0.1
|
| 362 |
+
- `warmup_steps`: 0
|
| 363 |
+
- `log_level`: passive
|
| 364 |
+
- `log_level_replica`: warning
|
| 365 |
+
- `log_on_each_node`: True
|
| 366 |
+
- `logging_nan_inf_filter`: True
|
| 367 |
+
- `save_safetensors`: True
|
| 368 |
+
- `save_on_each_node`: False
|
| 369 |
+
- `save_only_model`: False
|
| 370 |
+
- `restore_callback_states_from_checkpoint`: False
|
| 371 |
+
- `no_cuda`: False
|
| 372 |
+
- `use_cpu`: False
|
| 373 |
+
- `use_mps_device`: False
|
| 374 |
+
- `seed`: 42
|
| 375 |
+
- `data_seed`: None
|
| 376 |
+
- `jit_mode_eval`: False
|
| 377 |
+
- `bf16`: True
|
| 378 |
+
- `fp16`: False
|
| 379 |
+
- `fp16_opt_level`: O1
|
| 380 |
+
- `half_precision_backend`: auto
|
| 381 |
+
- `bf16_full_eval`: False
|
| 382 |
+
- `fp16_full_eval`: False
|
| 383 |
+
- `tf32`: None
|
| 384 |
+
- `local_rank`: 0
|
| 385 |
+
- `ddp_backend`: None
|
| 386 |
+
- `tpu_num_cores`: None
|
| 387 |
+
- `tpu_metrics_debug`: False
|
| 388 |
+
- `debug`: []
|
| 389 |
+
- `dataloader_drop_last`: True
|
| 390 |
+
- `dataloader_num_workers`: 2
|
| 391 |
+
- `dataloader_prefetch_factor`: 2
|
| 392 |
+
- `past_index`: -1
|
| 393 |
+
- `disable_tqdm`: False
|
| 394 |
+
- `remove_unused_columns`: True
|
| 395 |
+
- `label_names`: None
|
| 396 |
+
- `load_best_model_at_end`: True
|
| 397 |
+
- `ignore_data_skip`: False
|
| 398 |
+
- `fsdp`: []
|
| 399 |
+
- `fsdp_min_num_params`: 0
|
| 400 |
+
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
|
| 401 |
+
- `fsdp_transformer_layer_cls_to_wrap`: None
|
| 402 |
+
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
|
| 403 |
+
- `parallelism_config`: None
|
| 404 |
+
- `deepspeed`: None
|
| 405 |
+
- `label_smoothing_factor`: 0.0
|
| 406 |
+
- `optim`: adamw_torch_fused
|
| 407 |
+
- `optim_args`: None
|
| 408 |
+
- `adafactor`: False
|
| 409 |
+
- `group_by_length`: False
|
| 410 |
+
- `length_column_name`: length
|
| 411 |
+
- `project`: huggingface
|
| 412 |
+
- `trackio_space_id`: trackio
|
| 413 |
+
- `ddp_find_unused_parameters`: True
|
| 414 |
+
- `ddp_bucket_cap_mb`: None
|
| 415 |
+
- `ddp_broadcast_buffers`: False
|
| 416 |
+
- `dataloader_pin_memory`: True
|
| 417 |
+
- `dataloader_persistent_workers`: False
|
| 418 |
+
- `skip_memory_metrics`: True
|
| 419 |
+
- `use_legacy_prediction_loop`: False
|
| 420 |
+
- `push_to_hub`: True
|
| 421 |
+
- `resume_from_checkpoint`: None
|
| 422 |
+
- `hub_model_id`: thivy/norbert4-base-splade-retrieval
|
| 423 |
+
- `hub_strategy`: checkpoint
|
| 424 |
+
- `hub_private_repo`: False
|
| 425 |
+
- `hub_always_push`: False
|
| 426 |
+
- `hub_revision`: None
|
| 427 |
+
- `gradient_checkpointing`: True
|
| 428 |
+
- `gradient_checkpointing_kwargs`: {'use_reentrant': False}
|
| 429 |
+
- `include_inputs_for_metrics`: False
|
| 430 |
+
- `include_for_metrics`: []
|
| 431 |
+
- `eval_do_concat_batches`: True
|
| 432 |
+
- `fp16_backend`: auto
|
| 433 |
+
- `push_to_hub_model_id`: None
|
| 434 |
+
- `push_to_hub_organization`: None
|
| 435 |
+
- `mp_parameters`:
|
| 436 |
+
- `auto_find_batch_size`: False
|
| 437 |
+
- `full_determinism`: False
|
| 438 |
+
- `torchdynamo`: None
|
| 439 |
+
- `ray_scope`: last
|
| 440 |
+
- `ddp_timeout`: 1800
|
| 441 |
+
- `torch_compile`: False
|
| 442 |
+
- `torch_compile_backend`: None
|
| 443 |
+
- `torch_compile_mode`: None
|
| 444 |
+
- `include_tokens_per_second`: False
|
| 445 |
+
- `include_num_input_tokens_seen`: no
|
| 446 |
+
- `neftune_noise_alpha`: None
|
| 447 |
+
- `optim_target_modules`: None
|
| 448 |
+
- `batch_eval_metrics`: False
|
| 449 |
+
- `eval_on_start`: False
|
| 450 |
+
- `use_liger_kernel`: False
|
| 451 |
+
- `liger_kernel_config`: None
|
| 452 |
+
- `eval_use_gather_object`: False
|
| 453 |
+
- `average_tokens_across_devices`: True
|
| 454 |
+
- `prompts`: None
|
| 455 |
+
- `batch_sampler`: batch_sampler
|
| 456 |
+
- `multi_dataset_batch_sampler`: round_robin
|
| 457 |
+
- `router_mapping`: {}
|
| 458 |
+
- `learning_rate_mapping`: {}
|
| 459 |
+
|
| 460 |
+
</details>
|
| 461 |
+
|
| 462 |
+
### Training Logs
|
| 463 |
+
| Epoch | Step | Training Loss | Validation Loss | NanoNFCorpus_dot_ndcg@10 |
|
| 464 |
+
|:------:|:----:|:-------------:|:---------------:|:------------------------:|
|
| 465 |
+
| 0.0048 | 50 | 37895.69 | - | - |
|
| 466 |
+
| 0.0096 | 100 | 10002.0562 | - | - |
|
| 467 |
+
| 0.0144 | 150 | 3805.4731 | - | - |
|
| 468 |
+
| 0.0192 | 200 | 923.0944 | - | - |
|
| 469 |
+
| 0.0240 | 250 | 514.7795 | - | - |
|
| 470 |
+
| 0.0288 | 300 | 284.5449 | - | - |
|
| 471 |
+
| 0.0336 | 350 | 90.0678 | - | - |
|
| 472 |
+
| 0.0384 | 400 | 30.8482 | - | - |
|
| 473 |
+
| 0.0432 | 450 | 2.5071 | - | - |
|
| 474 |
+
| 0.0480 | 500 | 1.3525 | 2.2663 | 0.0272 |
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
### Framework Versions
|
| 478 |
+
- Python: 3.12.12
|
| 479 |
+
- Sentence Transformers: 5.2.0
|
| 480 |
+
- Transformers: 4.57.3
|
| 481 |
+
- PyTorch: 2.9.1+cu128
|
| 482 |
+
- Accelerate: 1.12.0
|
| 483 |
+
- Datasets: 4.4.2
|
| 484 |
+
- Tokenizers: 0.22.2
|
| 485 |
+
|
| 486 |
+
## Citation
|
| 487 |
+
|
| 488 |
+
### BibTeX
|
| 489 |
+
|
| 490 |
+
#### Sentence Transformers
|
| 491 |
+
```bibtex
|
| 492 |
+
@inproceedings{reimers-2019-sentence-bert,
|
| 493 |
+
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
|
| 494 |
+
author = "Reimers, Nils and Gurevych, Iryna",
|
| 495 |
+
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
|
| 496 |
+
month = "11",
|
| 497 |
+
year = "2019",
|
| 498 |
+
publisher = "Association for Computational Linguistics",
|
| 499 |
+
url = "https://arxiv.org/abs/1908.10084",
|
| 500 |
+
}
|
| 501 |
+
```
|
| 502 |
+
|
| 503 |
+
#### SpladeLoss
|
| 504 |
+
```bibtex
|
| 505 |
+
@misc{formal2022distillationhardnegativesampling,
|
| 506 |
+
title={From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective},
|
| 507 |
+
author={Thibault Formal and Carlos Lassance and Benjamin Piwowarski and Stéphane Clinchant},
|
| 508 |
+
year={2022},
|
| 509 |
+
eprint={2205.04733},
|
| 510 |
+
archivePrefix={arXiv},
|
| 511 |
+
primaryClass={cs.IR},
|
| 512 |
+
url={https://arxiv.org/abs/2205.04733},
|
| 513 |
+
}
|
| 514 |
+
```
|
| 515 |
+
|
| 516 |
+
#### SparseMultipleNegativesRankingLoss
|
| 517 |
+
```bibtex
|
| 518 |
+
@misc{henderson2017efficient,
|
| 519 |
+
title={Efficient Natural Language Response Suggestion for Smart Reply},
|
| 520 |
+
author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
|
| 521 |
+
year={2017},
|
| 522 |
+
eprint={1705.00652},
|
| 523 |
+
archivePrefix={arXiv},
|
| 524 |
+
primaryClass={cs.CL}
|
| 525 |
+
}
|
| 526 |
+
```
|
| 527 |
+
|
| 528 |
+
#### FlopsLoss
|
| 529 |
+
```bibtex
|
| 530 |
+
@article{paria2020minimizing,
|
| 531 |
+
title={Minimizing flops to learn efficient sparse representations},
|
| 532 |
+
author={Paria, Biswajit and Yeh, Chih-Kuan and Yen, Ian EH and Xu, Ning and Ravikumar, Pradeep and P{'o}czos, Barnab{'a}s},
|
| 533 |
+
journal={arXiv preprint arXiv:2004.05665},
|
| 534 |
+
year={2020}
|
| 535 |
+
}
|
| 536 |
+
```
|
| 537 |
+
|
| 538 |
+
<!--
|
| 539 |
+
## Glossary
|
| 540 |
+
|
| 541 |
+
*Clearly define terms in order to be accessible across audiences.*
|
| 542 |
+
-->
|
| 543 |
+
|
| 544 |
+
<!--
|
| 545 |
+
## Model Card Authors
|
| 546 |
+
|
| 547 |
+
*Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
|
| 548 |
+
-->
|
| 549 |
+
|
| 550 |
+
<!--
|
| 551 |
+
## Model Card Contact
|
| 552 |
+
|
| 553 |
+
*Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
|
| 554 |
+
-->
|
last-checkpoint/config.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"GptBertForMaskedLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"attn_implementation": null,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "configuration_gptbert.GptBertConfig",
|
| 9 |
+
"AutoModel": "modeling_gptbert.GptBertModel",
|
| 10 |
+
"AutoModelForCausalLM": "modeling_gptbert.GptBertForCausalLM",
|
| 11 |
+
"AutoModelForMaskedLM": "modeling_gptbert.GptBertForMaskedLM",
|
| 12 |
+
"AutoModelForMultipleChoice": "modeling_gptbert.GptBertForMultipleChoice",
|
| 13 |
+
"AutoModelForQuestionAnswering": "modeling_gptbert.GptBertForQuestionAnswering",
|
| 14 |
+
"AutoModelForSequenceClassification": "modeling_gptbert.GptBertForSequenceClassification",
|
| 15 |
+
"AutoModelForTokenClassification": "modeling_gptbert.GptBertForTokenClassification"
|
| 16 |
+
},
|
| 17 |
+
"bos_token_id": 1,
|
| 18 |
+
"classifier_dropout": 0.2,
|
| 19 |
+
"deterministic_flash_attn": false,
|
| 20 |
+
"dtype": "float32",
|
| 21 |
+
"embedding_dropout": 0.1,
|
| 22 |
+
"eos_token_id": 2,
|
| 23 |
+
"global_window_length": 8192,
|
| 24 |
+
"hidden_dropout": 0.0,
|
| 25 |
+
"hidden_size": 640,
|
| 26 |
+
"intermediate_size": 1664,
|
| 27 |
+
"layer_norm_eps": 1e-07,
|
| 28 |
+
"local_global_ratio": 4,
|
| 29 |
+
"local_window_length": 256,
|
| 30 |
+
"mask_token_id": 4,
|
| 31 |
+
"max_sequence_length": 16384,
|
| 32 |
+
"model": "norbert4",
|
| 33 |
+
"num_attention_heads": 10,
|
| 34 |
+
"num_layers": 24,
|
| 35 |
+
"pad_token_id": 3,
|
| 36 |
+
"query_key_head_size": 64,
|
| 37 |
+
"rope_theta": 160000,
|
| 38 |
+
"transformers_version": "4.57.3",
|
| 39 |
+
"unk_token_id": 0,
|
| 40 |
+
"use_cache": false,
|
| 41 |
+
"value_head_size": 64,
|
| 42 |
+
"vocab_size": 51200
|
| 43 |
+
}
|
last-checkpoint/config_sentence_transformers.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "SparseEncoder",
|
| 3 |
+
"__version__": {
|
| 4 |
+
"sentence_transformers": "5.2.0",
|
| 5 |
+
"transformers": "4.57.3",
|
| 6 |
+
"pytorch": "2.9.1+cu128"
|
| 7 |
+
},
|
| 8 |
+
"prompts": {
|
| 9 |
+
"query": "",
|
| 10 |
+
"document": ""
|
| 11 |
+
},
|
| 12 |
+
"default_prompt_name": null,
|
| 13 |
+
"similarity_fn_name": "dot"
|
| 14 |
+
}
|
last-checkpoint/configuration_gptbert.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import copy
|
| 6 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GptBertConfig(PretrainedConfig):
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
config_file: Path | str | None = None,
|
| 14 |
+
**kwargs
|
| 15 |
+
):
|
| 16 |
+
super().__init__(**kwargs)
|
| 17 |
+
self.model = "norbert4"
|
| 18 |
+
|
| 19 |
+
if config_file is not None:
|
| 20 |
+
if type(config_file) is str:
|
| 21 |
+
config_file = Path(config_file)
|
| 22 |
+
assert type(config_file) is not Path, "The config_file should either be a Path or str"
|
| 23 |
+
with config_file.open("r") as file:
|
| 24 |
+
config = json.load(file)
|
| 25 |
+
|
| 26 |
+
for attr, value in config.items():
|
| 27 |
+
if isinstance(value, str):
|
| 28 |
+
value = value.lower()
|
| 29 |
+
setattr(self, attr, value)
|
| 30 |
+
|
| 31 |
+
for attr, value in kwargs.items():
|
| 32 |
+
if isinstance(value, str):
|
| 33 |
+
value = value.lower()
|
| 34 |
+
setattr(self, attr, value)
|
last-checkpoint/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca6b59b6342fcd6a1910b237e5db7707f98673239940cfc25a5d1876082ebc33
|
| 3 |
+
size 728561776
|
last-checkpoint/modeling_gptbert.py
ADDED
|
@@ -0,0 +1,1105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch import _softmax_backward_data as _softmax_backward_data
|
| 7 |
+
|
| 8 |
+
from functools import partial, lru_cache
|
| 9 |
+
|
| 10 |
+
from .configuration_gptbert import GptBertConfig
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
from transformers.activations import gelu_new
|
| 13 |
+
from transformers.utils import is_flash_attn_2_available, logging
|
| 14 |
+
from transformers.modeling_outputs import (
|
| 15 |
+
MaskedLMOutput,
|
| 16 |
+
MultipleChoiceModelOutput,
|
| 17 |
+
QuestionAnsweringModelOutput,
|
| 18 |
+
SequenceClassifierOutput,
|
| 19 |
+
TokenClassifierOutput,
|
| 20 |
+
BaseModelOutput,
|
| 21 |
+
CausalLMOutput
|
| 22 |
+
)
|
| 23 |
+
import math
|
| 24 |
+
from typing import TYPE_CHECKING, Optional, Union, Tuple, List
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Workaround for transformers < 4.36.0 check_imports issue
|
| 30 |
+
# See: https://github.com/huggingface/transformers/issues/28459
|
| 31 |
+
try:
|
| 32 |
+
if is_flash_attn_2_available():
|
| 33 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
| 34 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
| 35 |
+
from flash_attn.ops.triton.rotary import apply_rotary
|
| 36 |
+
else:
|
| 37 |
+
flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
|
| 38 |
+
logger.warning_once(
|
| 39 |
+
"NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
|
| 40 |
+
)
|
| 41 |
+
except ImportError:
|
| 42 |
+
flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
|
| 43 |
+
logger.warning_once(
|
| 44 |
+
"NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 49 |
+
@torch.compiler.disable()
|
| 50 |
+
def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
| 51 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 52 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 53 |
+
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
|
| 54 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 55 |
+
|
| 56 |
+
if input_ids.dim() == 2:
|
| 57 |
+
unpadded_inputs = input_ids.flatten()[indices]
|
| 58 |
+
else:
|
| 59 |
+
batch_size, sequence_length, *rest = input_ids.shape
|
| 60 |
+
shape = batch_size * sequence_length
|
| 61 |
+
unpadded_inputs = input_ids.view(shape, *rest)[indices]
|
| 62 |
+
|
| 63 |
+
return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 67 |
+
def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
|
| 68 |
+
if input_ids.dim() == 1:
|
| 69 |
+
output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
|
| 70 |
+
output[indices] = input_ids
|
| 71 |
+
padded_inputs = output.view(batch_size, sequence_length)
|
| 72 |
+
else:
|
| 73 |
+
_, *rest = input_ids.shape
|
| 74 |
+
output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
|
| 75 |
+
output[indices] = input_ids
|
| 76 |
+
padded_inputs = output.view(batch_size, sequence_length, *rest)
|
| 77 |
+
|
| 78 |
+
return padded_inputs
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CastedLinear(nn.Linear):
|
| 82 |
+
def __init__(self, in_features, out_features, bias):
|
| 83 |
+
super().__init__(in_features, out_features, bias=bias)
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CastedLinearIn(nn.Linear):
|
| 90 |
+
def __init__(self, in_features, out_features, bias):
|
| 91 |
+
super().__init__(in_features, out_features, bias=bias)
|
| 92 |
+
self.scale = nn.Parameter(torch.ones(in_features))
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MultiCastedLinearOrthoIn(nn.Module):
|
| 99 |
+
def __init__(self, in_features, out_features, bias):
|
| 100 |
+
super().__init__()
|
| 101 |
+
|
| 102 |
+
self.in_features = in_features
|
| 103 |
+
self.out_features = out_features
|
| 104 |
+
|
| 105 |
+
self.weights = nn.ParameterList()
|
| 106 |
+
for out_feature in out_features:
|
| 107 |
+
self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
|
| 108 |
+
|
| 109 |
+
if bias:
|
| 110 |
+
self.bias = nn.Parameter(torch.zeros(sum(out_features)))
|
| 111 |
+
else:
|
| 112 |
+
self.bias = self.register_parameter("bias", None)
|
| 113 |
+
|
| 114 |
+
self.scale = nn.Parameter(torch.ones(in_features))
|
| 115 |
+
|
| 116 |
+
def forward(self, x):
|
| 117 |
+
return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class GeGLU(nn.Module):
|
| 121 |
+
def forward(self, x):
|
| 122 |
+
x, gate = x.chunk(2, dim=-1)
|
| 123 |
+
return x * gelu_new(gate)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class Embedding(nn.Module):
|
| 127 |
+
def __init__(self, config: GptBertConfig):
|
| 128 |
+
super().__init__()
|
| 129 |
+
|
| 130 |
+
self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 131 |
+
self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
|
| 132 |
+
self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
|
| 133 |
+
self.dropout = nn.Dropout(config.embedding_dropout)
|
| 134 |
+
|
| 135 |
+
def forward(self, input_ids: torch.Tensor):
|
| 136 |
+
word_embedding = self.word_embedding(input_ids)
|
| 137 |
+
word_embedding = self.word_norm(word_embedding)
|
| 138 |
+
word_embedding = word_embedding * (self.word_scale + 1.0)
|
| 139 |
+
|
| 140 |
+
return self.dropout(word_embedding)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class LMClassifier(nn.Module):
|
| 144 |
+
def __init__(self, config: GptBertConfig, n_labels: int):
|
| 145 |
+
super().__init__()
|
| 146 |
+
|
| 147 |
+
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 148 |
+
self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
|
| 149 |
+
self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 150 |
+
self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
|
| 151 |
+
|
| 152 |
+
def forward(self, x: torch.Tensor):
|
| 153 |
+
x = self.pre_norm(x.float()).type_as(x)
|
| 154 |
+
x = self.projection(x)
|
| 155 |
+
x = gelu_new(x)
|
| 156 |
+
x = self.post_norm(x.float()).type_as(x)
|
| 157 |
+
x = self.emb2vocab(x)
|
| 158 |
+
return x
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class Classifier(nn.Module):
|
| 162 |
+
def __init__(self, config: GptBertConfig, n_labels: int):
|
| 163 |
+
super().__init__()
|
| 164 |
+
|
| 165 |
+
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 166 |
+
self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
|
| 167 |
+
self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 168 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
| 169 |
+
self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
|
| 170 |
+
|
| 171 |
+
def forward(self, x: torch.Tensor):
|
| 172 |
+
x = self.pre_norm(x.float()).type_as(x)
|
| 173 |
+
x = self.projection(x)
|
| 174 |
+
x = gelu_new(x)
|
| 175 |
+
x = self.post_norm(x.float()).type_as(x)
|
| 176 |
+
x = self.dropout(x)
|
| 177 |
+
x = self.output_projection(x)
|
| 178 |
+
return x
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 182 |
+
def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
|
| 183 |
+
qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
| 184 |
+
|
| 185 |
+
convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
|
| 186 |
+
if convert_dtype:
|
| 187 |
+
# FA2 implementation only supports fp16 and bf16. If FA2 is supported,
|
| 188 |
+
# bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
|
| 189 |
+
orig_dtype = qkv.dtype
|
| 190 |
+
qkv = qkv.to(target_dtype)
|
| 191 |
+
|
| 192 |
+
attn = flash_attn_varlen_qkvpacked_func(
|
| 193 |
+
qkv,
|
| 194 |
+
cu_seqlens=cu_seqlens,
|
| 195 |
+
max_seqlen=max_seqlen,
|
| 196 |
+
dropout_p=dropout_p,
|
| 197 |
+
deterministic=deterministic,
|
| 198 |
+
window_size=local_attention,
|
| 199 |
+
causal=False
|
| 200 |
+
)
|
| 201 |
+
attn = attn.to(orig_dtype) # type: ignore
|
| 202 |
+
else:
|
| 203 |
+
attn = flash_attn_varlen_qkvpacked_func(
|
| 204 |
+
qkv,
|
| 205 |
+
cu_seqlens=cu_seqlens,
|
| 206 |
+
max_seqlen=max_seqlen,
|
| 207 |
+
dropout_p=dropout_p,
|
| 208 |
+
deterministic=deterministic,
|
| 209 |
+
window_size=local_attention,
|
| 210 |
+
causal=False
|
| 211 |
+
)
|
| 212 |
+
return attn
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 216 |
+
class ApplyRotaryEmbUnpad(torch.autograd.Function):
|
| 217 |
+
@staticmethod
|
| 218 |
+
def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
|
| 219 |
+
# (total_nnz, 3, nheads, headdim)
|
| 220 |
+
qkv = qkv.contiguous()
|
| 221 |
+
total_nnz, _three, _nheads, headdim = qkv.shape
|
| 222 |
+
# We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
|
| 223 |
+
# we get the same tensor
|
| 224 |
+
# qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
|
| 225 |
+
qk = qkv[:, :2].view(total_nnz, -1, headdim)
|
| 226 |
+
apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
|
| 227 |
+
|
| 228 |
+
ctx.save_for_backward(cos, sin, cu_seqlens)
|
| 229 |
+
ctx.max_seqlen = max_seqlen
|
| 230 |
+
return qkv
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def backward(ctx, do):
|
| 234 |
+
cos, sin, cu_seqlens = ctx.saved_tensors
|
| 235 |
+
do = do.contiguous()
|
| 236 |
+
total_nnz, _three, _nheads, headdim = do.shape
|
| 237 |
+
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
|
| 238 |
+
# we get the same tensor
|
| 239 |
+
dqk = do[:, :2].view(total_nnz, -1, headdim)
|
| 240 |
+
apply_rotary(
|
| 241 |
+
dqk,
|
| 242 |
+
cos,
|
| 243 |
+
sin,
|
| 244 |
+
seqlen_offsets=0,
|
| 245 |
+
cu_seqlens=cu_seqlens,
|
| 246 |
+
max_seqlen=ctx.max_seqlen,
|
| 247 |
+
interleaved=False,
|
| 248 |
+
inplace=True,
|
| 249 |
+
conjugate=True,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
return do, None, None, None, None, None, None
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 256 |
+
def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
|
| 257 |
+
return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 261 |
+
class UnpaddedRotaryEmbedding(RotaryEmbedding):
|
| 262 |
+
def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
|
| 263 |
+
super().__init__(dim=dim, base=base, device=None, interleaved=False)
|
| 264 |
+
self.max_seqlen = max_seqlen
|
| 265 |
+
|
| 266 |
+
def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 267 |
+
if max_seqlen is not None:
|
| 268 |
+
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
| 269 |
+
|
| 270 |
+
qkv = apply_rotary_unpadded(
|
| 271 |
+
qkv,
|
| 272 |
+
self._cos_cached,
|
| 273 |
+
self._sin_cached,
|
| 274 |
+
cu_seqlens=cu_seqlens,
|
| 275 |
+
max_seqlen=max_seqlen,
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
return qkv
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class RotaryPositionalEmbeddings(nn.Module):
|
| 282 |
+
def __init__(self, config, theta: int):
|
| 283 |
+
super().__init__()
|
| 284 |
+
|
| 285 |
+
head_size = config.query_key_head_size
|
| 286 |
+
assert head_size % 2 == 0
|
| 287 |
+
max_seq_len = config.max_sequence_length
|
| 288 |
+
|
| 289 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
|
| 290 |
+
pos = torch.arange(max_seq_len, dtype=torch.float32)
|
| 291 |
+
embedding = torch.einsum('n, d -> nd', pos, inv_freq)
|
| 292 |
+
embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
|
| 293 |
+
self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
|
| 294 |
+
self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
|
| 295 |
+
|
| 296 |
+
def forward(self, x: torch.Tensor):
|
| 297 |
+
hidden_layer = x.float()
|
| 298 |
+
|
| 299 |
+
seq_len = x.shape[2]
|
| 300 |
+
|
| 301 |
+
cos_matrix = self.cos_matrix[:, None, :seq_len, :]
|
| 302 |
+
sin_matrix = self.sin_matrix[:, None, :seq_len, :]
|
| 303 |
+
|
| 304 |
+
x_rotate_half = torch.cat(
|
| 305 |
+
[
|
| 306 |
+
-hidden_layer[:, :, :, x.size(-1) // 2:],
|
| 307 |
+
hidden_layer[:, :, :, :x.size(-1) // 2]
|
| 308 |
+
],
|
| 309 |
+
dim=-1
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
|
| 313 |
+
return out.type_as(x)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class MaskedSoftmax(torch.autograd.Function):
|
| 317 |
+
@staticmethod
|
| 318 |
+
def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
|
| 319 |
+
ctx.dim = dim
|
| 320 |
+
x.masked_fill_(mask, float('-inf'))
|
| 321 |
+
x = torch.softmax(x, ctx.dim)
|
| 322 |
+
x.masked_fill_(mask, 0.0)
|
| 323 |
+
ctx.save_for_backward(x)
|
| 324 |
+
return x
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
|
| 328 |
+
output: torch.Tensor
|
| 329 |
+
|
| 330 |
+
output, = ctx.saved_tensors
|
| 331 |
+
inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
|
| 332 |
+
return inputGrad, None, None
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class SelfAttention(nn.Module):
|
| 336 |
+
def __init__(self, config: GptBertConfig, layer_idx: int):
|
| 337 |
+
super().__init__()
|
| 338 |
+
|
| 339 |
+
self.config = config
|
| 340 |
+
self.layer_idx = layer_idx
|
| 341 |
+
|
| 342 |
+
self.d_qk = config.query_key_head_size
|
| 343 |
+
self.d_v = config.value_head_size
|
| 344 |
+
self.num_attention_heads = config.num_attention_heads
|
| 345 |
+
self.num_kv_heads = config.num_attention_heads
|
| 346 |
+
self.hidden_size = config.hidden_size
|
| 347 |
+
|
| 348 |
+
self.q_out_dim = self.d_qk * self.num_attention_heads
|
| 349 |
+
self.k_out_dim = self.d_qk * self.num_kv_heads
|
| 350 |
+
self.v_out_dim = self.d_v * self.num_kv_heads
|
| 351 |
+
|
| 352 |
+
self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
|
| 353 |
+
self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
|
| 354 |
+
self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
|
| 355 |
+
|
| 356 |
+
self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 357 |
+
self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 358 |
+
self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 359 |
+
self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
|
| 360 |
+
self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
|
| 361 |
+
self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
|
| 362 |
+
self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
|
| 363 |
+
|
| 364 |
+
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
| 365 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 366 |
+
|
| 367 |
+
theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
|
| 368 |
+
|
| 369 |
+
# Initialize rotary embeddings based on whether FlashAttention is available
|
| 370 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 371 |
+
self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
|
| 372 |
+
else:
|
| 373 |
+
self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
|
| 374 |
+
|
| 375 |
+
self.scale = 1.0 / math.sqrt(self.d_qk)
|
| 376 |
+
self.lambdas = nn.Parameter(torch.tensor([0.5]))
|
| 377 |
+
|
| 378 |
+
self.sequence_length = config.max_sequence_length
|
| 379 |
+
self.is_causal = config.is_decoder
|
| 380 |
+
self.window_length = None
|
| 381 |
+
|
| 382 |
+
def set_window_length(self, window_length: int):
|
| 383 |
+
self.window_length = window_length
|
| 384 |
+
|
| 385 |
+
def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
|
| 386 |
+
"""Create and cache window attention mask."""
|
| 387 |
+
if self.is_causal:
|
| 388 |
+
mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
|
| 389 |
+
mask = mask.tril().triu(diagonal=-self.window_length)
|
| 390 |
+
else:
|
| 391 |
+
mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
|
| 392 |
+
mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
|
| 393 |
+
return mask.view(1, 1, query_length, key_length)
|
| 394 |
+
|
| 395 |
+
def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 396 |
+
"""Standard attention computation with masking."""
|
| 397 |
+
batch_size, _, query_length, _ = query.size()
|
| 398 |
+
_, _, key_length, _ = key.size()
|
| 399 |
+
|
| 400 |
+
# Use cached window mask
|
| 401 |
+
with torch.no_grad():
|
| 402 |
+
window_mask = self._get_window_mask(query_length, key_length, query.device)
|
| 403 |
+
if padding_mask is not None:
|
| 404 |
+
attention_mask = padding_mask & window_mask
|
| 405 |
+
else:
|
| 406 |
+
attention_mask = window_mask
|
| 407 |
+
|
| 408 |
+
attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, Q_T, K_T]
|
| 409 |
+
attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
|
| 410 |
+
|
| 411 |
+
attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
|
| 412 |
+
attention_probabilities = self.attention_dropout(attention_probabilities)
|
| 413 |
+
|
| 414 |
+
output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
|
| 415 |
+
output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
|
| 416 |
+
|
| 417 |
+
return output
|
| 418 |
+
|
| 419 |
+
def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
|
| 420 |
+
# Get original shape info
|
| 421 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 422 |
+
# Unpadded case
|
| 423 |
+
indices, cu_seqlens, max_seqlen = padding_info
|
| 424 |
+
total_seqlen = hidden_layer.size(0)
|
| 425 |
+
batch_size = cu_seqlens.size(0) - 1
|
| 426 |
+
else:
|
| 427 |
+
# Padded case
|
| 428 |
+
batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
|
| 429 |
+
|
| 430 |
+
hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
|
| 431 |
+
qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
|
| 432 |
+
|
| 433 |
+
query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
|
| 434 |
+
value = self.v_proj(hidden_layer)
|
| 435 |
+
|
| 436 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 437 |
+
# Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
|
| 438 |
+
query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
|
| 439 |
+
key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
|
| 440 |
+
value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
|
| 441 |
+
|
| 442 |
+
# Apply layer norm and scaling
|
| 443 |
+
query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
|
| 444 |
+
key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
|
| 445 |
+
|
| 446 |
+
if v1 is None:
|
| 447 |
+
v1 = value
|
| 448 |
+
value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
|
| 449 |
+
|
| 450 |
+
# Prepare qkv for FlashAttention
|
| 451 |
+
qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
|
| 452 |
+
|
| 453 |
+
# Determine window size for local attention
|
| 454 |
+
if self.window_length is not None and self.window_length > 0:
|
| 455 |
+
if self.is_causal:
|
| 456 |
+
local_attention = (self.window_length - 1, 0)
|
| 457 |
+
else:
|
| 458 |
+
local_attention = (self.window_length - 1, self.window_length - 1)
|
| 459 |
+
else:
|
| 460 |
+
local_attention = (-1, -1)
|
| 461 |
+
|
| 462 |
+
# Apply FlashAttention
|
| 463 |
+
output = flash_attention_forward(
|
| 464 |
+
qkv,
|
| 465 |
+
self.rope_embedding,
|
| 466 |
+
cu_seqlens,
|
| 467 |
+
max_seqlen,
|
| 468 |
+
self.is_causal,
|
| 469 |
+
local_attention,
|
| 470 |
+
self.config.attention_dropout if self.training else 0.0,
|
| 471 |
+
self.config.deterministic_flash_attn
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Reshape output back
|
| 475 |
+
output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
|
| 476 |
+
|
| 477 |
+
else:
|
| 478 |
+
# Standard attention path
|
| 479 |
+
query_length = query.size(1)
|
| 480 |
+
key_length = key.size(1)
|
| 481 |
+
|
| 482 |
+
query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
|
| 483 |
+
key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
|
| 484 |
+
value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
|
| 485 |
+
|
| 486 |
+
query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
|
| 487 |
+
key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
|
| 488 |
+
|
| 489 |
+
if v1 is None:
|
| 490 |
+
v1 = value
|
| 491 |
+
else:
|
| 492 |
+
value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
|
| 493 |
+
|
| 494 |
+
# Apply rotary embeddings
|
| 495 |
+
query = self.rope_embedding(query)
|
| 496 |
+
key = self.rope_embedding(key)
|
| 497 |
+
|
| 498 |
+
output = self.attention_operation(query, key, value, padding_info)
|
| 499 |
+
output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T, H*D]
|
| 500 |
+
|
| 501 |
+
output = self.inter_norm(output.float()).type_as(output)
|
| 502 |
+
output = self.out_proj(output)
|
| 503 |
+
output = self.dropout(output)
|
| 504 |
+
|
| 505 |
+
return output, v1
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
class FeedForward(nn.Module):
|
| 509 |
+
def __init__(self, config: GptBertConfig):
|
| 510 |
+
super().__init__()
|
| 511 |
+
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 512 |
+
self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
|
| 513 |
+
self.activation = GeGLU()
|
| 514 |
+
self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 515 |
+
self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
|
| 516 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 517 |
+
|
| 518 |
+
def forward(self, x: torch.Tensor):
|
| 519 |
+
x = self.pre_norm(x.float()).type_as(x)
|
| 520 |
+
x = self.up_proj(x)
|
| 521 |
+
x = self.activation(x)
|
| 522 |
+
x = self.inter_norm(x.float()).type_as(x)
|
| 523 |
+
x = self.down_proj(x)
|
| 524 |
+
x = self.dropout(x)
|
| 525 |
+
return x
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class Layer(nn.Module):
|
| 529 |
+
def __init__(self, config: GptBertConfig, layer_idx: int):
|
| 530 |
+
super().__init__()
|
| 531 |
+
|
| 532 |
+
self.attention = SelfAttention(config, layer_idx)
|
| 533 |
+
self.mlp = FeedForward(config)
|
| 534 |
+
self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
|
| 535 |
+
|
| 536 |
+
def set_window_length(self, window_length: int):
|
| 537 |
+
self.attention.set_window_length(window_length)
|
| 538 |
+
|
| 539 |
+
def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
|
| 540 |
+
attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
|
| 541 |
+
qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
|
| 542 |
+
mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
|
| 543 |
+
|
| 544 |
+
attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
|
| 545 |
+
mlp_layer = mlp_layer + attention_output
|
| 546 |
+
hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
|
| 547 |
+
output = hidden_layer + attention_output + self.mlp(mlp_layer)
|
| 548 |
+
|
| 549 |
+
return output, v1
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class Encoder(nn.Module):
|
| 553 |
+
def __init__(self, config: GptBertConfig):
|
| 554 |
+
super().__init__()
|
| 555 |
+
self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
|
| 556 |
+
self.local_global_ratio = config.local_global_ratio
|
| 557 |
+
|
| 558 |
+
def set_window_length(self, config: GptBertConfig):
|
| 559 |
+
for i, layer in enumerate(self.layers):
|
| 560 |
+
if (i + 1) % self.local_global_ratio == 0:
|
| 561 |
+
layer.set_window_length(config.global_window_length)
|
| 562 |
+
else:
|
| 563 |
+
layer.set_window_length(config.local_window_length)
|
| 564 |
+
|
| 565 |
+
def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
|
| 566 |
+
hidden_layers = [hidden_layer] if output_hidden_states else None
|
| 567 |
+
v1 = None
|
| 568 |
+
embeddings = hidden_layer
|
| 569 |
+
|
| 570 |
+
for layer in self.layers:
|
| 571 |
+
if checkpoint_activations:
|
| 572 |
+
hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
|
| 573 |
+
else:
|
| 574 |
+
hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
|
| 575 |
+
|
| 576 |
+
if output_hidden_states:
|
| 577 |
+
hidden_layers.append(hidden_layer)
|
| 578 |
+
|
| 579 |
+
return hidden_layer, hidden_layers
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
#
|
| 583 |
+
# HuggingFace wrappers
|
| 584 |
+
#
|
| 585 |
+
|
| 586 |
+
class GptBertPreTrainedModel(PreTrainedModel):
|
| 587 |
+
config_class = GptBertConfig
|
| 588 |
+
supports_gradient_checkpointing = True
|
| 589 |
+
_supports_flash_attn_2 = True
|
| 590 |
+
_supports_sdpa = True
|
| 591 |
+
_supports_flex_attn = False
|
| 592 |
+
|
| 593 |
+
def _init_weights(self, module):
|
| 594 |
+
std = math.sqrt(2.0 / (5.0 * self.hidden_size))
|
| 595 |
+
|
| 596 |
+
if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
|
| 597 |
+
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 598 |
+
if module.bias is not None:
|
| 599 |
+
module.bias.data.zero_()
|
| 600 |
+
elif isinstance(module, nn.Embedding):
|
| 601 |
+
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 602 |
+
elif isinstance(module, nn.LayerNorm):
|
| 603 |
+
module.bias.data.zero_()
|
| 604 |
+
module.weight.data.fill_(1.0)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class GptBertModel(GptBertPreTrainedModel):
|
| 608 |
+
def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
|
| 609 |
+
super().__init__(config, **kwargs)
|
| 610 |
+
self.config = config
|
| 611 |
+
self.hidden_size = config.hidden_size
|
| 612 |
+
|
| 613 |
+
self.embedding = Embedding(config)
|
| 614 |
+
self.encoder = Encoder(config)
|
| 615 |
+
self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
|
| 616 |
+
self.set_window_length(config)
|
| 617 |
+
self.gradient_checkpointing = False
|
| 618 |
+
self.post_init()
|
| 619 |
+
|
| 620 |
+
def set_window_length(self, config) -> None:
|
| 621 |
+
self.encoder.set_window_length(config)
|
| 622 |
+
|
| 623 |
+
def get_input_embeddings(self):
|
| 624 |
+
return self.embedding.word_embedding
|
| 625 |
+
|
| 626 |
+
def set_input_embeddings(self, value):
|
| 627 |
+
self.embedding.word_embedding = value
|
| 628 |
+
|
| 629 |
+
def get_contextualized_embeddings(
|
| 630 |
+
self,
|
| 631 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 632 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 633 |
+
output_hidden_states: Optional[bool] = None
|
| 634 |
+
):
|
| 635 |
+
if input_ids is not None:
|
| 636 |
+
input_shape = input_ids.size()
|
| 637 |
+
else:
|
| 638 |
+
raise ValueError("You have to specify input_ids")
|
| 639 |
+
|
| 640 |
+
batch_size, seq_length = input_shape
|
| 641 |
+
device = input_ids.device
|
| 642 |
+
|
| 643 |
+
if attention_mask is None:
|
| 644 |
+
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
|
| 645 |
+
else:
|
| 646 |
+
attention_mask = attention_mask.bool()
|
| 647 |
+
|
| 648 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 649 |
+
if len(attention_mask.size()) != 2:
|
| 650 |
+
raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
|
| 651 |
+
with torch.no_grad():
|
| 652 |
+
input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
|
| 653 |
+
padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
|
| 654 |
+
else:
|
| 655 |
+
if len(attention_mask.size()) == 2:
|
| 656 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 657 |
+
elif len(attention_mask.size()) == 3:
|
| 658 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 659 |
+
padding_info = attention_mask
|
| 660 |
+
|
| 661 |
+
static_embeddings = self.embedding(input_ids)
|
| 662 |
+
|
| 663 |
+
original_dtype = static_embeddings.dtype
|
| 664 |
+
if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
|
| 665 |
+
static_embeddings = static_embeddings.bfloat16()
|
| 666 |
+
|
| 667 |
+
last_layer, contextualized_embeddings = self.encoder(
|
| 668 |
+
static_embeddings,
|
| 669 |
+
padding_info,
|
| 670 |
+
output_hidden_states=output_hidden_states,
|
| 671 |
+
checkpoint_activations=self.gradient_checkpointing and self.training
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
last_layer = last_layer.to(original_dtype)
|
| 675 |
+
if output_hidden_states:
|
| 676 |
+
contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
|
| 677 |
+
|
| 678 |
+
# Pad output if using FlashAttention
|
| 679 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 680 |
+
last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
|
| 681 |
+
if output_hidden_states:
|
| 682 |
+
contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
|
| 683 |
+
else:
|
| 684 |
+
contextualized_embeddings = None
|
| 685 |
+
|
| 686 |
+
return last_layer, contextualized_embeddings
|
| 687 |
+
|
| 688 |
+
def forward(
|
| 689 |
+
self,
|
| 690 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 691 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 692 |
+
output_hidden_states: Optional[bool] = None,
|
| 693 |
+
output_attentions: Optional[bool] = None,
|
| 694 |
+
return_dict: Optional[bool] = None,
|
| 695 |
+
**kwargs
|
| 696 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
|
| 697 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 698 |
+
|
| 699 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 700 |
+
|
| 701 |
+
if not return_dict:
|
| 702 |
+
return (
|
| 703 |
+
sequence_output,
|
| 704 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
return BaseModelOutput(
|
| 708 |
+
last_hidden_state=sequence_output,
|
| 709 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
class GptBertForMaskedLM(GptBertModel):
|
| 714 |
+
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 715 |
+
|
| 716 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 717 |
+
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 718 |
+
|
| 719 |
+
def get_output_embeddings(self):
|
| 720 |
+
return self.classifier.emb2vocab.weight
|
| 721 |
+
|
| 722 |
+
def set_output_embeddings(self, new_embeddings):
|
| 723 |
+
self.classifier.emb2vocab.weight = new_embeddings
|
| 724 |
+
|
| 725 |
+
def forward(
|
| 726 |
+
self,
|
| 727 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 728 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 729 |
+
output_hidden_states: Optional[bool] = None,
|
| 730 |
+
return_dict: Optional[bool] = None,
|
| 731 |
+
labels: Optional[torch.LongTensor] = None,
|
| 732 |
+
**kwargs
|
| 733 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 734 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 735 |
+
|
| 736 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 737 |
+
subword_prediction = self.classifier(sequence_output)
|
| 738 |
+
subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
|
| 739 |
+
|
| 740 |
+
masked_lm_loss = None
|
| 741 |
+
if labels is not None:
|
| 742 |
+
labels_flatten = labels[:, 1:].flatten()
|
| 743 |
+
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 744 |
+
masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 745 |
+
|
| 746 |
+
bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
|
| 747 |
+
bos_logits[:, :, self.config.bos_token_id] = 1.0
|
| 748 |
+
subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
|
| 749 |
+
|
| 750 |
+
if not return_dict:
|
| 751 |
+
output = (
|
| 752 |
+
subword_prediction,
|
| 753 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 754 |
+
)
|
| 755 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 756 |
+
|
| 757 |
+
return MaskedLMOutput(
|
| 758 |
+
loss=masked_lm_loss,
|
| 759 |
+
logits=subword_prediction,
|
| 760 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
class GptBertForCausalLM(GptBertModel):
|
| 765 |
+
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 766 |
+
|
| 767 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 768 |
+
config.is_decoder = True
|
| 769 |
+
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 770 |
+
|
| 771 |
+
def get_output_embeddings(self):
|
| 772 |
+
return self.classifier.emb2vocab.weight
|
| 773 |
+
|
| 774 |
+
def set_output_embeddings(self, new_embeddings):
|
| 775 |
+
self.classifier.emb2vocab.weight = new_embeddings
|
| 776 |
+
|
| 777 |
+
def get_input_embeddings(self):
|
| 778 |
+
return self.embedding.word_embedding
|
| 779 |
+
|
| 780 |
+
def set_input_embeddings(self, value):
|
| 781 |
+
self.embedding.word_embedding = value
|
| 782 |
+
|
| 783 |
+
def set_decoder(self, decoder):
|
| 784 |
+
self.encoder = decoder
|
| 785 |
+
|
| 786 |
+
def get_decoder(self):
|
| 787 |
+
return self.encoder
|
| 788 |
+
|
| 789 |
+
def can_generate(self):
|
| 790 |
+
return True
|
| 791 |
+
|
| 792 |
+
def forward(
|
| 793 |
+
self,
|
| 794 |
+
input_ids: torch.LongTensor = None,
|
| 795 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 796 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 797 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 798 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 799 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 800 |
+
labels: Optional[torch.LongTensor] = None,
|
| 801 |
+
use_cache: Optional[bool] = None,
|
| 802 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 803 |
+
output_attentions: Optional[bool] = None,
|
| 804 |
+
output_hidden_states: Optional[bool] = None,
|
| 805 |
+
return_dict: Optional[bool] = None
|
| 806 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 807 |
+
|
| 808 |
+
assert inputs_embeds is None, "inputs_embeds is not supported for now"
|
| 809 |
+
assert past_key_values is None, "past_key_values is not supported for now"
|
| 810 |
+
assert not use_cache, "use_cache is not supported for now"
|
| 811 |
+
|
| 812 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 813 |
+
subword_prediction = self.classifier(sequence_output)
|
| 814 |
+
subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
|
| 815 |
+
|
| 816 |
+
causal_lm_loss = None
|
| 817 |
+
if labels is not None:
|
| 818 |
+
labels_flatten = labels[:, 1:].flatten()
|
| 819 |
+
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 820 |
+
causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 821 |
+
|
| 822 |
+
if not return_dict:
|
| 823 |
+
output = (
|
| 824 |
+
subword_prediction,
|
| 825 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 826 |
+
)
|
| 827 |
+
return ((causal_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 828 |
+
|
| 829 |
+
return CausalLMOutput(
|
| 830 |
+
loss=causal_lm_loss,
|
| 831 |
+
logits=subword_prediction,
|
| 832 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
def prepare_inputs_for_generation(
|
| 836 |
+
self,
|
| 837 |
+
input_ids: torch.Tensor,
|
| 838 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 839 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 840 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 841 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 842 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 843 |
+
use_cache: bool = True,
|
| 844 |
+
num_logits_to_keep: Optional[int] = None,
|
| 845 |
+
**kwargs,
|
| 846 |
+
):
|
| 847 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
| 848 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
| 849 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
| 850 |
+
if past_key_values is not None:
|
| 851 |
+
if inputs_embeds is not None: # Exception 1
|
| 852 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
| 853 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
| 854 |
+
input_ids = input_ids[:, cache_position]
|
| 855 |
+
|
| 856 |
+
if attention_mask is not None and position_ids is None:
|
| 857 |
+
# create position_ids on the fly for batch generation
|
| 858 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 859 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 860 |
+
if past_key_values:
|
| 861 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 862 |
+
|
| 863 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
| 864 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
| 865 |
+
|
| 866 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 867 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
| 868 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 869 |
+
else:
|
| 870 |
+
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
| 871 |
+
|
| 872 |
+
if num_logits_to_keep is not None:
|
| 873 |
+
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
| 874 |
+
|
| 875 |
+
model_inputs.update(
|
| 876 |
+
{
|
| 877 |
+
"position_ids": position_ids,
|
| 878 |
+
"cache_position": cache_position,
|
| 879 |
+
"past_key_values": past_key_values,
|
| 880 |
+
"use_cache": use_cache,
|
| 881 |
+
"attention_mask": attention_mask,
|
| 882 |
+
}
|
| 883 |
+
)
|
| 884 |
+
return model_inputs
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
class GptBertForSequenceClassification(GptBertModel):
|
| 888 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 889 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 890 |
+
|
| 891 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 892 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 893 |
+
|
| 894 |
+
self.num_labels = config.num_labels
|
| 895 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 896 |
+
self.post_init()
|
| 897 |
+
|
| 898 |
+
def forward(
|
| 899 |
+
self,
|
| 900 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 901 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 902 |
+
output_hidden_states: Optional[bool] = None,
|
| 903 |
+
return_dict: Optional[bool] = None,
|
| 904 |
+
labels: Optional[torch.LongTensor] = None,
|
| 905 |
+
**kwargs
|
| 906 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 907 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 908 |
+
|
| 909 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 910 |
+
logits = self.classifier(sequence_output[:, 0, :])
|
| 911 |
+
|
| 912 |
+
loss = None
|
| 913 |
+
if labels is not None:
|
| 914 |
+
if self.config.problem_type is None:
|
| 915 |
+
if self.num_labels == 1:
|
| 916 |
+
self.config.problem_type = "regression"
|
| 917 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 918 |
+
self.config.problem_type = "single_label_classification"
|
| 919 |
+
else:
|
| 920 |
+
self.config.problem_type = "multi_label_classification"
|
| 921 |
+
|
| 922 |
+
if self.config.problem_type == "regression":
|
| 923 |
+
loss_fct = nn.MSELoss()
|
| 924 |
+
if self.num_labels == 1:
|
| 925 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 926 |
+
else:
|
| 927 |
+
loss = loss_fct(logits, labels)
|
| 928 |
+
elif self.config.problem_type == "single_label_classification":
|
| 929 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 930 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 931 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 932 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
| 933 |
+
loss = loss_fct(logits, labels)
|
| 934 |
+
|
| 935 |
+
if not return_dict:
|
| 936 |
+
output = (
|
| 937 |
+
logits,
|
| 938 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 939 |
+
)
|
| 940 |
+
return ((loss,) + output) if loss is not None else output
|
| 941 |
+
|
| 942 |
+
return SequenceClassifierOutput(
|
| 943 |
+
loss=loss,
|
| 944 |
+
logits=logits,
|
| 945 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
class GptBertForTokenClassification(GptBertModel):
|
| 950 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 951 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 952 |
+
|
| 953 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 954 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 955 |
+
|
| 956 |
+
self.num_labels = config.num_labels
|
| 957 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 958 |
+
self.post_init()
|
| 959 |
+
|
| 960 |
+
def forward(
|
| 961 |
+
self,
|
| 962 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 963 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 964 |
+
output_hidden_states: Optional[bool] = None,
|
| 965 |
+
return_dict: Optional[bool] = None,
|
| 966 |
+
labels: Optional[torch.LongTensor] = None,
|
| 967 |
+
**kwargs
|
| 968 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
| 969 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 970 |
+
|
| 971 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 972 |
+
logits = self.classifier(sequence_output)
|
| 973 |
+
|
| 974 |
+
loss = None
|
| 975 |
+
if labels is not None:
|
| 976 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 977 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 978 |
+
|
| 979 |
+
if not return_dict:
|
| 980 |
+
output = (
|
| 981 |
+
logits,
|
| 982 |
+
*([contextualized_embeddings] if output_hidden_states else []),
|
| 983 |
+
*([attention_probs] if output_attentions else [])
|
| 984 |
+
)
|
| 985 |
+
return ((loss,) + output) if loss is not None else output
|
| 986 |
+
|
| 987 |
+
return TokenClassifierOutput(
|
| 988 |
+
loss=loss,
|
| 989 |
+
logits=logits,
|
| 990 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None,
|
| 991 |
+
attentions=attention_probs if output_attentions else None
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
class GptBertForQuestionAnswering(GptBertModel):
|
| 996 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 997 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 998 |
+
|
| 999 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 1000 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 1001 |
+
|
| 1002 |
+
self.num_labels = config.num_labels
|
| 1003 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 1004 |
+
self.post_init()
|
| 1005 |
+
|
| 1006 |
+
def forward(
|
| 1007 |
+
self,
|
| 1008 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1009 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1010 |
+
output_hidden_states: Optional[bool] = None,
|
| 1011 |
+
return_dict: Optional[bool] = None,
|
| 1012 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 1013 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 1014 |
+
**kwargs
|
| 1015 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
| 1016 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1017 |
+
|
| 1018 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 1019 |
+
logits = self.classifier(sequence_output)
|
| 1020 |
+
|
| 1021 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1022 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1023 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1024 |
+
|
| 1025 |
+
total_loss = None
|
| 1026 |
+
if start_positions is not None and end_positions is not None:
|
| 1027 |
+
# If we are on multi-GPU, split add a dimension
|
| 1028 |
+
if len(start_positions.size()) > 1:
|
| 1029 |
+
start_positions = start_positions.squeeze(-1)
|
| 1030 |
+
if len(end_positions.size()) > 1:
|
| 1031 |
+
end_positions = end_positions.squeeze(-1)
|
| 1032 |
+
|
| 1033 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1034 |
+
ignored_index = start_logits.size(1)
|
| 1035 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1036 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1037 |
+
|
| 1038 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
|
| 1039 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1040 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1041 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1042 |
+
|
| 1043 |
+
if not return_dict:
|
| 1044 |
+
output = (
|
| 1045 |
+
start_logits,
|
| 1046 |
+
end_logits,
|
| 1047 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1048 |
+
)
|
| 1049 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1050 |
+
|
| 1051 |
+
return QuestionAnsweringModelOutput(
|
| 1052 |
+
loss=total_loss,
|
| 1053 |
+
start_logits=start_logits,
|
| 1054 |
+
end_logits=end_logits,
|
| 1055 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 1056 |
+
)
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
class GptBertForMultipleChoice(GptBertModel):
|
| 1060 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1061 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1062 |
+
|
| 1063 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 1064 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 1065 |
+
|
| 1066 |
+
self.num_labels = getattr(config, "num_labels", 2)
|
| 1067 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 1068 |
+
self.post_init()
|
| 1069 |
+
|
| 1070 |
+
def forward(
|
| 1071 |
+
self,
|
| 1072 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1073 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1074 |
+
labels: Optional[torch.Tensor] = None,
|
| 1075 |
+
output_hidden_states: Optional[bool] = None,
|
| 1076 |
+
return_dict: Optional[bool] = None,
|
| 1077 |
+
**kwargs
|
| 1078 |
+
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
| 1079 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1080 |
+
num_choices = input_ids.shape[1]
|
| 1081 |
+
|
| 1082 |
+
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
| 1083 |
+
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 1084 |
+
|
| 1085 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
|
| 1086 |
+
logits = self.classifier(sequence_output)
|
| 1087 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 1088 |
+
|
| 1089 |
+
loss = None
|
| 1090 |
+
if labels is not None:
|
| 1091 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 1092 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 1093 |
+
|
| 1094 |
+
if not return_dict:
|
| 1095 |
+
output = (
|
| 1096 |
+
reshaped_logits,
|
| 1097 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1098 |
+
)
|
| 1099 |
+
return ((loss,) + output) if loss is not None else output
|
| 1100 |
+
|
| 1101 |
+
return MultipleChoiceModelOutput(
|
| 1102 |
+
loss=loss,
|
| 1103 |
+
logits=reshaped_logits,
|
| 1104 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 1105 |
+
)
|
last-checkpoint/modules.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"idx": 0,
|
| 4 |
+
"name": "0",
|
| 5 |
+
"path": "",
|
| 6 |
+
"type": "sentence_transformers.sparse_encoder.models.MLMTransformer"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"idx": 1,
|
| 10 |
+
"name": "1",
|
| 11 |
+
"path": "1_SpladePooling",
|
| 12 |
+
"type": "sentence_transformers.sparse_encoder.models.SpladePooling"
|
| 13 |
+
}
|
| 14 |
+
]
|
last-checkpoint/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ee91d450dde2beb96e8c2398912baa13f858a3b4f7cee07b28a9b96f3e588ef
|
| 3 |
+
size 1457369077
|
last-checkpoint/rng_state_0.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4cdc784e3b91bc23bce54961fdaef58e6442cd03f625edf44e230178fd37f8fa
|
| 3 |
+
size 14917
|
last-checkpoint/rng_state_1.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a36844f32afb06c561965a6f6eb81809058336e154b9d7e2fb6b83900a7ad0fa
|
| 3 |
+
size 14917
|
last-checkpoint/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d959e23fadc9c5a5f14d7b9c3a56d1fd374b1bcf4b39d4b142f83de164ff2685
|
| 3 |
+
size 1465
|
last-checkpoint/sentence_bert_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_seq_length": null,
|
| 3 |
+
"do_lower_case": false
|
| 4 |
+
}
|
last-checkpoint/special_tokens_map.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"cls_token": {
|
| 10 |
+
"content": "<s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"eos_token": {
|
| 17 |
+
"content": "</s>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"mask_token": {
|
| 24 |
+
"content": "<mask>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"pad_token": {
|
| 31 |
+
"content": "<pad>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
},
|
| 37 |
+
"sep_token": {
|
| 38 |
+
"content": "</s>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false
|
| 43 |
+
},
|
| 44 |
+
"unk_token": {
|
| 45 |
+
"content": "<unk>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false
|
| 50 |
+
}
|
| 51 |
+
}
|
last-checkpoint/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
last-checkpoint/tokenizer_config.json
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<unk>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<s>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "</s>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<pad>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "<mask>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
},
|
| 43 |
+
"5": {
|
| 44 |
+
"content": "<special_0>",
|
| 45 |
+
"lstrip": false,
|
| 46 |
+
"normalized": false,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"single_word": false,
|
| 49 |
+
"special": true
|
| 50 |
+
},
|
| 51 |
+
"6": {
|
| 52 |
+
"content": "<special_1>",
|
| 53 |
+
"lstrip": false,
|
| 54 |
+
"normalized": false,
|
| 55 |
+
"rstrip": false,
|
| 56 |
+
"single_word": false,
|
| 57 |
+
"special": true
|
| 58 |
+
},
|
| 59 |
+
"7": {
|
| 60 |
+
"content": "<special_2>",
|
| 61 |
+
"lstrip": false,
|
| 62 |
+
"normalized": false,
|
| 63 |
+
"rstrip": false,
|
| 64 |
+
"single_word": false,
|
| 65 |
+
"special": true
|
| 66 |
+
},
|
| 67 |
+
"8": {
|
| 68 |
+
"content": "<special_3>",
|
| 69 |
+
"lstrip": false,
|
| 70 |
+
"normalized": false,
|
| 71 |
+
"rstrip": false,
|
| 72 |
+
"single_word": false,
|
| 73 |
+
"special": true
|
| 74 |
+
},
|
| 75 |
+
"9": {
|
| 76 |
+
"content": "<special_4>",
|
| 77 |
+
"lstrip": false,
|
| 78 |
+
"normalized": false,
|
| 79 |
+
"rstrip": false,
|
| 80 |
+
"single_word": false,
|
| 81 |
+
"special": true
|
| 82 |
+
},
|
| 83 |
+
"10": {
|
| 84 |
+
"content": "<special_5>",
|
| 85 |
+
"lstrip": false,
|
| 86 |
+
"normalized": false,
|
| 87 |
+
"rstrip": false,
|
| 88 |
+
"single_word": false,
|
| 89 |
+
"special": true
|
| 90 |
+
},
|
| 91 |
+
"11": {
|
| 92 |
+
"content": "<special_6>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": false,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false,
|
| 97 |
+
"special": true
|
| 98 |
+
},
|
| 99 |
+
"12": {
|
| 100 |
+
"content": "<special_7>",
|
| 101 |
+
"lstrip": false,
|
| 102 |
+
"normalized": false,
|
| 103 |
+
"rstrip": false,
|
| 104 |
+
"single_word": false,
|
| 105 |
+
"special": true
|
| 106 |
+
},
|
| 107 |
+
"13": {
|
| 108 |
+
"content": "<special_8>",
|
| 109 |
+
"lstrip": false,
|
| 110 |
+
"normalized": false,
|
| 111 |
+
"rstrip": false,
|
| 112 |
+
"single_word": false,
|
| 113 |
+
"special": true
|
| 114 |
+
},
|
| 115 |
+
"14": {
|
| 116 |
+
"content": "<special_9>",
|
| 117 |
+
"lstrip": false,
|
| 118 |
+
"normalized": false,
|
| 119 |
+
"rstrip": false,
|
| 120 |
+
"single_word": false,
|
| 121 |
+
"special": true
|
| 122 |
+
},
|
| 123 |
+
"15": {
|
| 124 |
+
"content": "<special_10>",
|
| 125 |
+
"lstrip": false,
|
| 126 |
+
"normalized": false,
|
| 127 |
+
"rstrip": false,
|
| 128 |
+
"single_word": false,
|
| 129 |
+
"special": true
|
| 130 |
+
}
|
| 131 |
+
},
|
| 132 |
+
"bos_token": "<s>",
|
| 133 |
+
"clean_up_tokenization_spaces": false,
|
| 134 |
+
"cls_token": "<s>",
|
| 135 |
+
"eos_token": "</s>",
|
| 136 |
+
"extra_special_tokens": {},
|
| 137 |
+
"mask_token": "<mask>",
|
| 138 |
+
"model_max_length": 4096,
|
| 139 |
+
"pad_token": "<pad>",
|
| 140 |
+
"sep_token": "</s>",
|
| 141 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 142 |
+
"unk_token": "<unk>"
|
| 143 |
+
}
|
last-checkpoint/trainer_state.json
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": 500,
|
| 3 |
+
"best_metric": 0.027178706104522946,
|
| 4 |
+
"best_model_checkpoint": "models/splade-norbert4-base-retrieval-only/checkpoint-500",
|
| 5 |
+
"epoch": 0.04797083373309028,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 500,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"base_loss": 37895.6328,
|
| 14 |
+
"document_regularizer_loss": 0.0598,
|
| 15 |
+
"epoch": 0.004797083373309028,
|
| 16 |
+
"grad_norm": 259787.75,
|
| 17 |
+
"learning_rate": 9.395973154362417e-07,
|
| 18 |
+
"loss": 37895.69,
|
| 19 |
+
"query_regularizer_loss": 0.001,
|
| 20 |
+
"step": 50
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"base_loss": 10001.6025,
|
| 24 |
+
"document_regularizer_loss": 0.4482,
|
| 25 |
+
"epoch": 0.009594166746618057,
|
| 26 |
+
"grad_norm": 505128.8125,
|
| 27 |
+
"learning_rate": 1.8983700862895495e-06,
|
| 28 |
+
"loss": 10002.0562,
|
| 29 |
+
"query_regularizer_loss": 0.0037,
|
| 30 |
+
"step": 100
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"base_loss": 3804.3779,
|
| 34 |
+
"document_regularizer_loss": 1.0922,
|
| 35 |
+
"epoch": 0.014391250119927085,
|
| 36 |
+
"grad_norm": 53786.4296875,
|
| 37 |
+
"learning_rate": 2.8571428571428573e-06,
|
| 38 |
+
"loss": 3805.4731,
|
| 39 |
+
"query_regularizer_loss": 0.0023,
|
| 40 |
+
"step": 150
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"base_loss": 921.3414,
|
| 44 |
+
"document_regularizer_loss": 1.7523,
|
| 45 |
+
"epoch": 0.019188333493236114,
|
| 46 |
+
"grad_norm": 48469.69140625,
|
| 47 |
+
"learning_rate": 3.815915627996165e-06,
|
| 48 |
+
"loss": 923.0944,
|
| 49 |
+
"query_regularizer_loss": 0.0007,
|
| 50 |
+
"step": 200
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"base_loss": 512.5709,
|
| 54 |
+
"document_regularizer_loss": 2.2081,
|
| 55 |
+
"epoch": 0.02398541686654514,
|
| 56 |
+
"grad_norm": 9211.7822265625,
|
| 57 |
+
"learning_rate": 4.774688398849473e-06,
|
| 58 |
+
"loss": 514.7795,
|
| 59 |
+
"query_regularizer_loss": 0.0005,
|
| 60 |
+
"step": 250
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"base_loss": 282.497,
|
| 64 |
+
"document_regularizer_loss": 2.0475,
|
| 65 |
+
"epoch": 0.02878250023985417,
|
| 66 |
+
"grad_norm": 430.40460205078125,
|
| 67 |
+
"learning_rate": 5.733461169702781e-06,
|
| 68 |
+
"loss": 284.5449,
|
| 69 |
+
"query_regularizer_loss": 0.0003,
|
| 70 |
+
"step": 300
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"base_loss": 88.6157,
|
| 74 |
+
"document_regularizer_loss": 1.4521,
|
| 75 |
+
"epoch": 0.0335795836131632,
|
| 76 |
+
"grad_norm": 660.3614501953125,
|
| 77 |
+
"learning_rate": 6.692233940556089e-06,
|
| 78 |
+
"loss": 90.0678,
|
| 79 |
+
"query_regularizer_loss": 0.0001,
|
| 80 |
+
"step": 350
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"base_loss": 30.6636,
|
| 84 |
+
"document_regularizer_loss": 0.1846,
|
| 85 |
+
"epoch": 0.03837666698647223,
|
| 86 |
+
"grad_norm": 3.8934402465820312,
|
| 87 |
+
"learning_rate": 7.651006711409396e-06,
|
| 88 |
+
"loss": 30.8482,
|
| 89 |
+
"query_regularizer_loss": 0.0,
|
| 90 |
+
"step": 400
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"base_loss": 2.5066,
|
| 94 |
+
"document_regularizer_loss": 0.0005,
|
| 95 |
+
"epoch": 0.04317375035978125,
|
| 96 |
+
"grad_norm": 26.094982147216797,
|
| 97 |
+
"learning_rate": 8.609779482262704e-06,
|
| 98 |
+
"loss": 2.5071,
|
| 99 |
+
"query_regularizer_loss": 0.0,
|
| 100 |
+
"step": 450
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"base_loss": 1.3518,
|
| 104 |
+
"document_regularizer_loss": 0.0007,
|
| 105 |
+
"epoch": 0.04797083373309028,
|
| 106 |
+
"grad_norm": 20.548200607299805,
|
| 107 |
+
"learning_rate": 9.568552253116012e-06,
|
| 108 |
+
"loss": 1.3525,
|
| 109 |
+
"query_regularizer_loss": 0.0,
|
| 110 |
+
"step": 500
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"epoch": 0.04797083373309028,
|
| 114 |
+
"eval_NanoBEIR_mean_avg_flops": 51200.0,
|
| 115 |
+
"eval_NanoBEIR_mean_corpus_active_dims": 51200.0,
|
| 116 |
+
"eval_NanoBEIR_mean_corpus_sparsity_ratio": 0.0,
|
| 117 |
+
"eval_NanoBEIR_mean_dot_accuracy@1": 0.02,
|
| 118 |
+
"eval_NanoBEIR_mean_dot_accuracy@10": 0.12,
|
| 119 |
+
"eval_NanoBEIR_mean_dot_accuracy@3": 0.08,
|
| 120 |
+
"eval_NanoBEIR_mean_dot_accuracy@5": 0.08,
|
| 121 |
+
"eval_NanoBEIR_mean_dot_map@100": 0.006747512755501429,
|
| 122 |
+
"eval_NanoBEIR_mean_dot_mrr@10": 0.05088888888888889,
|
| 123 |
+
"eval_NanoBEIR_mean_dot_ndcg@10": 0.027178706104522946,
|
| 124 |
+
"eval_NanoBEIR_mean_dot_precision@1": 0.02,
|
| 125 |
+
"eval_NanoBEIR_mean_dot_precision@10": 0.026000000000000006,
|
| 126 |
+
"eval_NanoBEIR_mean_dot_precision@3": 0.03333333333333333,
|
| 127 |
+
"eval_NanoBEIR_mean_dot_precision@5": 0.032,
|
| 128 |
+
"eval_NanoBEIR_mean_dot_recall@1": 7.905138339920947e-05,
|
| 129 |
+
"eval_NanoBEIR_mean_dot_recall@10": 0.006349071275176555,
|
| 130 |
+
"eval_NanoBEIR_mean_dot_recall@3": 0.003312410422185988,
|
| 131 |
+
"eval_NanoBEIR_mean_dot_recall@5": 0.004545769460972766,
|
| 132 |
+
"eval_NanoBEIR_mean_query_active_dims": 51200.0,
|
| 133 |
+
"eval_NanoBEIR_mean_query_sparsity_ratio": 0.0,
|
| 134 |
+
"eval_NanoNFCorpus_avg_flops": 51200.0,
|
| 135 |
+
"eval_NanoNFCorpus_corpus_active_dims": 51200.0,
|
| 136 |
+
"eval_NanoNFCorpus_corpus_sparsity_ratio": 0.0,
|
| 137 |
+
"eval_NanoNFCorpus_dot_accuracy@1": 0.02,
|
| 138 |
+
"eval_NanoNFCorpus_dot_accuracy@10": 0.12,
|
| 139 |
+
"eval_NanoNFCorpus_dot_accuracy@3": 0.08,
|
| 140 |
+
"eval_NanoNFCorpus_dot_accuracy@5": 0.08,
|
| 141 |
+
"eval_NanoNFCorpus_dot_map@100": 0.006747512755501429,
|
| 142 |
+
"eval_NanoNFCorpus_dot_mrr@10": 0.05088888888888889,
|
| 143 |
+
"eval_NanoNFCorpus_dot_ndcg@10": 0.027178706104522946,
|
| 144 |
+
"eval_NanoNFCorpus_dot_precision@1": 0.02,
|
| 145 |
+
"eval_NanoNFCorpus_dot_precision@10": 0.026000000000000006,
|
| 146 |
+
"eval_NanoNFCorpus_dot_precision@3": 0.03333333333333333,
|
| 147 |
+
"eval_NanoNFCorpus_dot_precision@5": 0.032,
|
| 148 |
+
"eval_NanoNFCorpus_dot_recall@1": 7.905138339920947e-05,
|
| 149 |
+
"eval_NanoNFCorpus_dot_recall@10": 0.006349071275176555,
|
| 150 |
+
"eval_NanoNFCorpus_dot_recall@3": 0.003312410422185988,
|
| 151 |
+
"eval_NanoNFCorpus_dot_recall@5": 0.004545769460972766,
|
| 152 |
+
"eval_NanoNFCorpus_query_active_dims": 51200.0,
|
| 153 |
+
"eval_NanoNFCorpus_query_sparsity_ratio": 0.0,
|
| 154 |
+
"eval_base_loss": 2.2657,
|
| 155 |
+
"eval_document_regularizer_loss": 0.0006,
|
| 156 |
+
"eval_loss": 2.2663323879241943,
|
| 157 |
+
"eval_query_regularizer_loss": 0.0,
|
| 158 |
+
"eval_runtime": 364.216,
|
| 159 |
+
"eval_samples_per_second": 39.696,
|
| 160 |
+
"eval_steps_per_second": 0.621,
|
| 161 |
+
"step": 500
|
| 162 |
+
}
|
| 163 |
+
],
|
| 164 |
+
"logging_steps": 50,
|
| 165 |
+
"max_steps": 10423,
|
| 166 |
+
"num_input_tokens_seen": 0,
|
| 167 |
+
"num_train_epochs": 1,
|
| 168 |
+
"save_steps": 500,
|
| 169 |
+
"stateful_callbacks": {
|
| 170 |
+
"TrainerControl": {
|
| 171 |
+
"args": {
|
| 172 |
+
"should_epoch_stop": false,
|
| 173 |
+
"should_evaluate": false,
|
| 174 |
+
"should_log": false,
|
| 175 |
+
"should_save": true,
|
| 176 |
+
"should_training_stop": false
|
| 177 |
+
},
|
| 178 |
+
"attributes": {}
|
| 179 |
+
}
|
| 180 |
+
},
|
| 181 |
+
"total_flos": 0.0,
|
| 182 |
+
"train_batch_size": 16,
|
| 183 |
+
"trial_name": null,
|
| 184 |
+
"trial_params": null
|
| 185 |
+
}
|
last-checkpoint/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:80a238e2a37bc1da8ffa7f6ac3192a8f23e297afa1bc8b0a28a0d40a8e101359
|
| 3 |
+
size 6353
|