Upload folder using huggingface_hub
Browse files- README.md +161 -0
- __init__.py +0 -0
- config.json +39 -0
- configuration_gptbert.py +34 -0
- modeling_gptbert.py +1123 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +10 -0
README.md
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- ka
|
| 4 |
+
- kat
|
| 5 |
+
inference: false
|
| 6 |
+
tags:
|
| 7 |
+
- BERT
|
| 8 |
+
- HPLT
|
| 9 |
+
- encoder
|
| 10 |
+
- text2text-generation
|
| 11 |
+
license: apache-2.0
|
| 12 |
+
datasets:
|
| 13 |
+
- HPLT/HPLT3.0
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# HPLT v3.0 GPT-BERT for Georgian
|
| 17 |
+
|
| 18 |
+
<img src="https://hplt-project.org/_next/static/media/logo-hplt.d5e16ca5.svg" width=12.5%>
|
| 19 |
+
|
| 20 |
+
This is one of the monolingual language models trained as a third release by the [HPLT project](https://hplt-project.org/).
|
| 21 |
+
Our models follow the setup of [GPT-BERT](https://aclanthology.org/2024.conll-babylm.24/).
|
| 22 |
+
|
| 23 |
+
All the HPLT GPT-BERT models use the same hyper-parameters:
|
| 24 |
+
- hidden size: 640
|
| 25 |
+
- attention heads: 10
|
| 26 |
+
- layers: 24
|
| 27 |
+
- vocabulary size: 32768
|
| 28 |
+
|
| 29 |
+
Every model uses its own tokenizer trained on language-specific HPLT data.
|
| 30 |
+
|
| 31 |
+
[The training code](https://github.com/ltgoslo/NorBERT/tree/main/norbert4).
|
| 32 |
+
|
| 33 |
+
## Example usage (bidirectional encoding)
|
| 34 |
+
|
| 35 |
+
This model currently needs a custom wrapper from `modeling_gptbert.py`, you should therefore load the model with `trust_remote_code=True`.
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
import torch
|
| 39 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 40 |
+
|
| 41 |
+
# Import model
|
| 42 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 43 |
+
"HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
|
| 44 |
+
)
|
| 45 |
+
model = AutoModelForMaskedLM.from_pretrained(
|
| 46 |
+
"HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
|
| 47 |
+
trust_remote_code=True,
|
| 48 |
+
use_safetensors=False,
|
| 49 |
+
)
|
| 50 |
+
model = model.eval()
|
| 51 |
+
input_text = f"Norwegian is a {tokenizer.mask_token} Germanic language"
|
| 52 |
+
print(input_text)
|
| 53 |
+
# Tokenize text (with a mask token inside)
|
| 54 |
+
input_text = tokenizer(
|
| 55 |
+
input_text,
|
| 56 |
+
return_tensors="pt",
|
| 57 |
+
)
|
| 58 |
+
# Inference
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
output_p = model(**input_text)
|
| 61 |
+
|
| 62 |
+
# Unmask the text
|
| 63 |
+
output_text = torch.where(
|
| 64 |
+
input_text.input_ids == tokenizer.mask_token_id,
|
| 65 |
+
output_p.logits.argmax(-1),
|
| 66 |
+
input_text.input_ids
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Decoding; should output: 'Norwegian is a North Germanic language'
|
| 70 |
+
print(tokenizer.decode(output_text[0].tolist()))
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Example usage (text generation)
|
| 74 |
+
|
| 75 |
+
GPT-BERT also supports unidirectional text decoding, it can generate text like any other GPT model:
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
import torch
|
| 79 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 80 |
+
|
| 81 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 82 |
+
"HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
|
| 83 |
+
)
|
| 84 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 85 |
+
"HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
|
| 86 |
+
trust_remote_code=True,
|
| 87 |
+
use_safetensors=False,
|
| 88 |
+
)
|
| 89 |
+
text = f"The Norwegian Constitution"
|
| 90 |
+
print(text, flush=True)
|
| 91 |
+
# Define tokens that should end the generation
|
| 92 |
+
eos_token_ids = [
|
| 93 |
+
token_id
|
| 94 |
+
for token_id in range(tokenizer.vocab_size)
|
| 95 |
+
if '.' in tokenizer.decode([token_id])
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
# Generation function
|
| 99 |
+
@torch.no_grad()
|
| 100 |
+
def generate(text):
|
| 101 |
+
input_ids = tokenizer(text, return_tensors='pt').input_ids
|
| 102 |
+
prediction = model.generate(
|
| 103 |
+
input_ids,
|
| 104 |
+
max_new_tokens=63,
|
| 105 |
+
do_sample=False,
|
| 106 |
+
eos_token_id=eos_token_ids,
|
| 107 |
+
)
|
| 108 |
+
return tokenizer.decode(prediction[0]).strip()
|
| 109 |
+
|
| 110 |
+
# Example usage, should output '[CLS]The Norwegian Constitution[SEP]is a document that defines the rights and responsibilities of the Norwegian people and their representatives.'
|
| 111 |
+
print(generate(text), flush=True)
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
The following classes are currently implemented: `AutoModel`, `AutoModelForMaskedLM`, `AutoModelForCausalLM`, `AutoModelForSequenceClassification`, `AutoModelForTokenClassification`, `AutoModelForQuestionAnswering` and `AutoModeltForMultipleChoice`.
|
| 115 |
+
|
| 116 |
+
## Intermediate checkpoints
|
| 117 |
+
|
| 118 |
+
We are releasing 10 intermediate checkpoints for each model at intervals of every 3125 training steps in separate branches. The naming convention is `stepXXX`: for example, `step18750`.
|
| 119 |
+
|
| 120 |
+
You can load a specific model revision with `transformers` using the argument `revision`:
|
| 121 |
+
```python
|
| 122 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("HPLT/hplt_gpt_bert_base_3_0_kat_Geor", revision="step21875", trust_remote_code=True)
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
You can access all the revisions for the models with the following code:
|
| 126 |
+
```python
|
| 127 |
+
from huggingface_hub import list_repo_refs
|
| 128 |
+
out = list_repo_refs("HPLT/hplt_gpt_bert_base_3_0_kat_Geor")
|
| 129 |
+
print([b.name for b in out.branches])
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
## Cite us
|
| 133 |
+
|
| 134 |
+
```bibtex
|
| 135 |
+
@inproceedings{charpentier-samuel-2024-bert,
|
| 136 |
+
title = "{GPT} or {BERT}: why not both?",
|
| 137 |
+
author = "Charpentier, Lucas Georges Gabriel and
|
| 138 |
+
Samuel, David",
|
| 139 |
+
booktitle = "The 2nd BabyLM Challenge at the 28th Conference on Computational Natural Language Learning",
|
| 140 |
+
month = nov,
|
| 141 |
+
year = "2024",
|
| 142 |
+
address = "Miami, FL, USA",
|
| 143 |
+
publisher = "Association for Computational Linguistics",
|
| 144 |
+
url = "https://aclanthology.org/2024.conll-babylm.24/",
|
| 145 |
+
pages = "262--283"
|
| 146 |
+
}
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
```bibtex
|
| 150 |
+
@misc{oepen2025hplt30largescalemultilingual,
|
| 151 |
+
title={{HPLT 3.0}: {V}ery Large-Scale Multilingual Resources for {LLM} and {MT}. Mono- and Bi-lingual Data, Multilingual Evaluation, and Pre-Trained Models},
|
| 152 |
+
author={Stephan Oepen and Nikolay Arefev and Mikko Aulamo and Marta Bañón and Maja Buljan and Laurie Burchell and Lucas Charpentier and Pinzhen Chen and Mariia Fedorova and Ona de Gibert and Barry Haddow and Jan Hajič and Jindřich Helcl and Andrey Kutuzov and Veronika Laippala and Zihao Li and Risto Luukkonen and Bhavitvya Malik and Vladislav Mikhailov and Amanda Myntti and Dayyán O'Brien and Lucie Poláková and Sampo Pyysalo and Gema Ramírez Sánchez and Janine Siewert and Pavel Stepachev and Jörg Tiedemann and Teemu Vahtola and Dušan Variš and Fedor Vitiugin and Tea Vojtěchová and Jaume Zaragoza},
|
| 153 |
+
year={2025},
|
| 154 |
+
eprint={2511.01066},
|
| 155 |
+
archivePrefix={arXiv},
|
| 156 |
+
primaryClass={cs.CL},
|
| 157 |
+
url={https://arxiv.org/abs/2511.01066},
|
| 158 |
+
}
|
| 159 |
+
```
|
| 160 |
+
[](https://arxiv.org/abs/2410.24159)
|
| 161 |
+
[](https://arxiv.org/abs/2511.01066)
|
__init__.py
ADDED
|
File without changes
|
config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"GptBertForMaskedLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_gptbert.GptBertConfig",
|
| 7 |
+
"AutoModel": "modeling_gptbert.GptBertModel",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_gptbert.GptBertForCausalLM",
|
| 9 |
+
"AutoModelForMaskedLM": "modeling_gptbert.GptBertForMaskedLM",
|
| 10 |
+
"AutoModelForSequenceClassification": "modeling_gptbert.GptBertForSequenceClassification",
|
| 11 |
+
"AutoModelForTokenClassification": "modeling_gptbert.GptBertForTokenClassification",
|
| 12 |
+
"AutoModelForQuestionAnswering": "modeling_gptbert.GptBertForQuestionAnswering",
|
| 13 |
+
"AutoModelForMultipleChoice": "modeling_gptbert.GptBertForMultipleChoice"
|
| 14 |
+
},
|
| 15 |
+
"unk_token_id": 1,
|
| 16 |
+
"bos_token_id": 2,
|
| 17 |
+
"eos_token_id": 3,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"mask_token_id": 4,
|
| 20 |
+
"hidden_size": 640,
|
| 21 |
+
"intermediate_size": 1664,
|
| 22 |
+
"max_sequence_length": 16384,
|
| 23 |
+
"num_layers": 24,
|
| 24 |
+
"attention_dropout": 0.0,
|
| 25 |
+
"hidden_dropout": 0.0,
|
| 26 |
+
"embedding_dropout": 0.1,
|
| 27 |
+
"classifier_dropout": 0.2,
|
| 28 |
+
"layer_norm_eps": 1e-07,
|
| 29 |
+
"query_key_head_size": 64,
|
| 30 |
+
"value_head_size": 64,
|
| 31 |
+
"num_attention_heads": 10,
|
| 32 |
+
"rope_theta": 160000,
|
| 33 |
+
"vocab_size": 32768,
|
| 34 |
+
"local_global_ratio": 4,
|
| 35 |
+
"global_window_length": 8192,
|
| 36 |
+
"local_window_length": 256,
|
| 37 |
+
"deterministic_flash_attn": false,
|
| 38 |
+
"use_cache": false
|
| 39 |
+
}
|
configuration_gptbert.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import copy
|
| 6 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GptBertConfig(PretrainedConfig):
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
config_file: Path | str | None = None,
|
| 14 |
+
**kwargs
|
| 15 |
+
):
|
| 16 |
+
super().__init__(**kwargs)
|
| 17 |
+
self.model = "norbert4"
|
| 18 |
+
|
| 19 |
+
if config_file is not None:
|
| 20 |
+
if type(config_file) is str:
|
| 21 |
+
config_file = Path(config_file)
|
| 22 |
+
assert type(config_file) is not Path, "The config_file should either be a Path or str"
|
| 23 |
+
with config_file.open("r") as file:
|
| 24 |
+
config = json.load(file)
|
| 25 |
+
|
| 26 |
+
for attr, value in config.items():
|
| 27 |
+
if isinstance(value, str):
|
| 28 |
+
value = value.lower()
|
| 29 |
+
setattr(self, attr, value)
|
| 30 |
+
|
| 31 |
+
for attr, value in kwargs.items():
|
| 32 |
+
if isinstance(value, str):
|
| 33 |
+
value = value.lower()
|
| 34 |
+
setattr(self, attr, value)
|
modeling_gptbert.py
ADDED
|
@@ -0,0 +1,1123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch import _softmax_backward_data as _softmax_backward_data
|
| 7 |
+
|
| 8 |
+
from functools import partial, lru_cache
|
| 9 |
+
|
| 10 |
+
from .configuration_gptbert import GptBertConfig
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
from transformers.activations import gelu_new
|
| 13 |
+
from transformers.utils import is_flash_attn_2_available, logging
|
| 14 |
+
from transformers.modeling_outputs import (
|
| 15 |
+
MaskedLMOutput,
|
| 16 |
+
MultipleChoiceModelOutput,
|
| 17 |
+
QuestionAnsweringModelOutput,
|
| 18 |
+
SequenceClassifierOutput,
|
| 19 |
+
TokenClassifierOutput,
|
| 20 |
+
BaseModelOutput,
|
| 21 |
+
CausalLMOutput
|
| 22 |
+
)
|
| 23 |
+
import math
|
| 24 |
+
from typing import TYPE_CHECKING, Optional, Union, Tuple, List
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Workaround for transformers < 4.36.0 check_imports issue
|
| 30 |
+
# See: https://github.com/huggingface/transformers/issues/28459
|
| 31 |
+
try:
|
| 32 |
+
if is_flash_attn_2_available():
|
| 33 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
| 34 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
| 35 |
+
from flash_attn.ops.triton.rotary import apply_rotary
|
| 36 |
+
else:
|
| 37 |
+
flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
|
| 38 |
+
logger.warning_once(
|
| 39 |
+
"NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
|
| 40 |
+
)
|
| 41 |
+
except ImportError:
|
| 42 |
+
flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
|
| 43 |
+
logger.warning_once(
|
| 44 |
+
"NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 49 |
+
@torch.compiler.disable()
|
| 50 |
+
def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
| 51 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 52 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 53 |
+
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
|
| 54 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 55 |
+
|
| 56 |
+
if input_ids.dim() == 2:
|
| 57 |
+
unpadded_inputs = input_ids.flatten()[indices]
|
| 58 |
+
else:
|
| 59 |
+
batch_size, sequence_length, *rest = input_ids.shape
|
| 60 |
+
shape = batch_size * sequence_length
|
| 61 |
+
unpadded_inputs = input_ids.view(shape, *rest)[indices]
|
| 62 |
+
|
| 63 |
+
return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 67 |
+
def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
|
| 68 |
+
if input_ids.dim() == 1:
|
| 69 |
+
output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
|
| 70 |
+
output[indices] = input_ids
|
| 71 |
+
padded_inputs = output.view(batch_size, sequence_length)
|
| 72 |
+
else:
|
| 73 |
+
_, *rest = input_ids.shape
|
| 74 |
+
output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
|
| 75 |
+
output[indices] = input_ids
|
| 76 |
+
padded_inputs = output.view(batch_size, sequence_length, *rest)
|
| 77 |
+
|
| 78 |
+
return padded_inputs
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CastedLinear(nn.Linear):
|
| 82 |
+
def __init__(self, in_features, out_features, bias):
|
| 83 |
+
super().__init__(in_features, out_features, bias=bias)
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CastedLinearIn(nn.Linear):
|
| 90 |
+
def __init__(self, in_features, out_features, bias):
|
| 91 |
+
super().__init__(in_features, out_features, bias=bias)
|
| 92 |
+
self.scale = nn.Parameter(torch.ones(in_features))
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MultiCastedLinearOrthoIn(nn.Module):
|
| 99 |
+
def __init__(self, in_features, out_features, bias):
|
| 100 |
+
super().__init__()
|
| 101 |
+
|
| 102 |
+
self.in_features = in_features
|
| 103 |
+
self.out_features = out_features
|
| 104 |
+
|
| 105 |
+
self.weights = nn.ParameterList()
|
| 106 |
+
for out_feature in out_features:
|
| 107 |
+
self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
|
| 108 |
+
|
| 109 |
+
if bias:
|
| 110 |
+
self.bias = nn.Parameter(torch.zeros(sum(out_features)))
|
| 111 |
+
else:
|
| 112 |
+
self.bias = self.register_parameter("bias", None)
|
| 113 |
+
|
| 114 |
+
self.scale = nn.Parameter(torch.ones(in_features))
|
| 115 |
+
|
| 116 |
+
def forward(self, x):
|
| 117 |
+
return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class GeGLU(nn.Module):
|
| 121 |
+
def forward(self, x):
|
| 122 |
+
x, gate = x.chunk(2, dim=-1)
|
| 123 |
+
return x * gelu_new(gate)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class Embedding(nn.Module):
|
| 127 |
+
def __init__(self, config: GptBertConfig):
|
| 128 |
+
super().__init__()
|
| 129 |
+
|
| 130 |
+
self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 131 |
+
self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
|
| 132 |
+
self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
|
| 133 |
+
self.dropout = nn.Dropout(config.embedding_dropout)
|
| 134 |
+
|
| 135 |
+
def forward(self, input_ids: torch.Tensor):
|
| 136 |
+
word_embedding = self.word_embedding(input_ids)
|
| 137 |
+
word_embedding = self.word_norm(word_embedding)
|
| 138 |
+
word_embedding = word_embedding * (self.word_scale + 1.0)
|
| 139 |
+
|
| 140 |
+
return self.dropout(word_embedding)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class LMClassifier(nn.Module):
|
| 144 |
+
def __init__(self, config: GptBertConfig, n_labels: int):
|
| 145 |
+
super().__init__()
|
| 146 |
+
|
| 147 |
+
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 148 |
+
self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
|
| 149 |
+
self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 150 |
+
self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
|
| 151 |
+
|
| 152 |
+
def forward(self, x: torch.Tensor):
|
| 153 |
+
x = self.pre_norm(x.float()).type_as(x)
|
| 154 |
+
x = self.projection(x)
|
| 155 |
+
x = gelu_new(x)
|
| 156 |
+
x = self.post_norm(x.float()).type_as(x)
|
| 157 |
+
x = self.emb2vocab(x)
|
| 158 |
+
return x
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class Classifier(nn.Module):
|
| 162 |
+
def __init__(self, config: GptBertConfig, n_labels: int):
|
| 163 |
+
super().__init__()
|
| 164 |
+
|
| 165 |
+
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 166 |
+
self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
|
| 167 |
+
self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 168 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
| 169 |
+
self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
|
| 170 |
+
|
| 171 |
+
def forward(self, x: torch.Tensor):
|
| 172 |
+
x = self.pre_norm(x.float()).type_as(x)
|
| 173 |
+
x = self.projection(x)
|
| 174 |
+
x = gelu_new(x)
|
| 175 |
+
x = self.post_norm(x.float()).type_as(x)
|
| 176 |
+
x = self.dropout(x)
|
| 177 |
+
x = self.output_projection(x)
|
| 178 |
+
return x
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 182 |
+
def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
|
| 183 |
+
qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
| 184 |
+
|
| 185 |
+
convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
|
| 186 |
+
if convert_dtype:
|
| 187 |
+
# FA2 implementation only supports fp16 and bf16. If FA2 is supported,
|
| 188 |
+
# bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
|
| 189 |
+
orig_dtype = qkv.dtype
|
| 190 |
+
qkv = qkv.to(target_dtype)
|
| 191 |
+
|
| 192 |
+
attn = flash_attn_varlen_qkvpacked_func(
|
| 193 |
+
qkv,
|
| 194 |
+
cu_seqlens=cu_seqlens,
|
| 195 |
+
max_seqlen=max_seqlen,
|
| 196 |
+
dropout_p=dropout_p,
|
| 197 |
+
deterministic=deterministic,
|
| 198 |
+
window_size=local_attention,
|
| 199 |
+
causal=False
|
| 200 |
+
)
|
| 201 |
+
attn = attn.to(orig_dtype) # type: ignore
|
| 202 |
+
else:
|
| 203 |
+
attn = flash_attn_varlen_qkvpacked_func(
|
| 204 |
+
qkv,
|
| 205 |
+
cu_seqlens=cu_seqlens,
|
| 206 |
+
max_seqlen=max_seqlen,
|
| 207 |
+
dropout_p=dropout_p,
|
| 208 |
+
deterministic=deterministic,
|
| 209 |
+
window_size=local_attention,
|
| 210 |
+
causal=False
|
| 211 |
+
)
|
| 212 |
+
return attn
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 216 |
+
class ApplyRotaryEmbUnpad(torch.autograd.Function):
|
| 217 |
+
@staticmethod
|
| 218 |
+
def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
|
| 219 |
+
# (total_nnz, 3, nheads, headdim)
|
| 220 |
+
qkv = qkv.contiguous()
|
| 221 |
+
total_nnz, _three, _nheads, headdim = qkv.shape
|
| 222 |
+
# We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
|
| 223 |
+
# we get the same tensor
|
| 224 |
+
# qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
|
| 225 |
+
qk = qkv[:, :2].view(total_nnz, -1, headdim)
|
| 226 |
+
apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
|
| 227 |
+
|
| 228 |
+
ctx.save_for_backward(cos, sin, cu_seqlens)
|
| 229 |
+
ctx.max_seqlen = max_seqlen
|
| 230 |
+
return qkv
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def backward(ctx, do):
|
| 234 |
+
cos, sin, cu_seqlens = ctx.saved_tensors
|
| 235 |
+
do = do.contiguous()
|
| 236 |
+
total_nnz, _three, _nheads, headdim = do.shape
|
| 237 |
+
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
|
| 238 |
+
# we get the same tensor
|
| 239 |
+
dqk = do[:, :2].view(total_nnz, -1, headdim)
|
| 240 |
+
apply_rotary(
|
| 241 |
+
dqk,
|
| 242 |
+
cos,
|
| 243 |
+
sin,
|
| 244 |
+
seqlen_offsets=0,
|
| 245 |
+
cu_seqlens=cu_seqlens,
|
| 246 |
+
max_seqlen=ctx.max_seqlen,
|
| 247 |
+
interleaved=False,
|
| 248 |
+
inplace=True,
|
| 249 |
+
conjugate=True,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
return do, None, None, None, None, None, None
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 256 |
+
def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
|
| 257 |
+
return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
|
| 261 |
+
class UnpaddedRotaryEmbedding(RotaryEmbedding):
|
| 262 |
+
def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
|
| 263 |
+
super().__init__(dim=dim, base=base, device=None, interleaved=False)
|
| 264 |
+
self.max_seqlen = max_seqlen
|
| 265 |
+
|
| 266 |
+
def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 267 |
+
if max_seqlen is not None:
|
| 268 |
+
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
| 269 |
+
|
| 270 |
+
qkv = apply_rotary_unpadded(
|
| 271 |
+
qkv,
|
| 272 |
+
self._cos_cached,
|
| 273 |
+
self._sin_cached,
|
| 274 |
+
cu_seqlens=cu_seqlens,
|
| 275 |
+
max_seqlen=max_seqlen,
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
return qkv
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class RotaryPositionalEmbeddings(nn.Module):
|
| 282 |
+
def __init__(self, config, theta: int):
|
| 283 |
+
super().__init__()
|
| 284 |
+
|
| 285 |
+
head_size = config.query_key_head_size
|
| 286 |
+
assert head_size % 2 == 0
|
| 287 |
+
max_seq_len = config.max_sequence_length
|
| 288 |
+
|
| 289 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
|
| 290 |
+
pos = torch.arange(max_seq_len, dtype=torch.float32)
|
| 291 |
+
embedding = torch.einsum('n, d -> nd', pos, inv_freq)
|
| 292 |
+
embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
|
| 293 |
+
self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
|
| 294 |
+
self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
|
| 295 |
+
|
| 296 |
+
def forward(self, x: torch.Tensor):
|
| 297 |
+
hidden_layer = x.float()
|
| 298 |
+
|
| 299 |
+
seq_len = x.shape[2]
|
| 300 |
+
|
| 301 |
+
cos_matrix = self.cos_matrix[:, None, :seq_len, :]
|
| 302 |
+
sin_matrix = self.sin_matrix[:, None, :seq_len, :]
|
| 303 |
+
|
| 304 |
+
x_rotate_half = torch.cat(
|
| 305 |
+
[
|
| 306 |
+
-hidden_layer[:, :, :, x.size(-1) // 2:],
|
| 307 |
+
hidden_layer[:, :, :, :x.size(-1) // 2]
|
| 308 |
+
],
|
| 309 |
+
dim=-1
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
|
| 313 |
+
return out.type_as(x)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class MaskedSoftmax(torch.autograd.Function):
|
| 317 |
+
@staticmethod
|
| 318 |
+
def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
|
| 319 |
+
ctx.dim = dim
|
| 320 |
+
x.masked_fill_(mask, float('-inf'))
|
| 321 |
+
x = torch.softmax(x, ctx.dim)
|
| 322 |
+
x.masked_fill_(mask, 0.0)
|
| 323 |
+
ctx.save_for_backward(x)
|
| 324 |
+
return x
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
|
| 328 |
+
output: torch.Tensor
|
| 329 |
+
|
| 330 |
+
output, = ctx.saved_tensors
|
| 331 |
+
inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
|
| 332 |
+
return inputGrad, None, None
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class SelfAttention(nn.Module):
|
| 336 |
+
def __init__(self, config: GptBertConfig, layer_idx: int):
|
| 337 |
+
super().__init__()
|
| 338 |
+
|
| 339 |
+
self.config = config
|
| 340 |
+
self.layer_idx = layer_idx
|
| 341 |
+
|
| 342 |
+
self.d_qk = config.query_key_head_size
|
| 343 |
+
self.d_v = config.value_head_size
|
| 344 |
+
self.num_attention_heads = config.num_attention_heads
|
| 345 |
+
self.num_kv_heads = config.num_attention_heads
|
| 346 |
+
self.hidden_size = config.hidden_size
|
| 347 |
+
|
| 348 |
+
self.q_out_dim = self.d_qk * self.num_attention_heads
|
| 349 |
+
self.k_out_dim = self.d_qk * self.num_kv_heads
|
| 350 |
+
self.v_out_dim = self.d_v * self.num_kv_heads
|
| 351 |
+
|
| 352 |
+
self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
|
| 353 |
+
self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
|
| 354 |
+
self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
|
| 355 |
+
|
| 356 |
+
self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 357 |
+
self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 358 |
+
self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 359 |
+
self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
|
| 360 |
+
self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
|
| 361 |
+
self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
|
| 362 |
+
self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
|
| 363 |
+
|
| 364 |
+
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
| 365 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 366 |
+
|
| 367 |
+
theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
|
| 368 |
+
|
| 369 |
+
# Initialize rotary embeddings based on whether FlashAttention is available
|
| 370 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 371 |
+
self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
|
| 372 |
+
else:
|
| 373 |
+
self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
|
| 374 |
+
|
| 375 |
+
self.scale = 1.0 / math.sqrt(self.d_qk)
|
| 376 |
+
#self.lambdas = nn.Parameter(torch.tensor([0.5]))
|
| 377 |
+
|
| 378 |
+
self.sequence_length = config.max_sequence_length
|
| 379 |
+
self.is_causal = config.is_decoder
|
| 380 |
+
self.window_length = None
|
| 381 |
+
|
| 382 |
+
def set_window_length(self, window_length: int):
|
| 383 |
+
self.window_length = window_length
|
| 384 |
+
|
| 385 |
+
def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
|
| 386 |
+
"""Create and cache window attention mask."""
|
| 387 |
+
if self.is_causal:
|
| 388 |
+
mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
|
| 389 |
+
mask = mask.tril().triu(diagonal=-self.window_length)
|
| 390 |
+
else:
|
| 391 |
+
mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
|
| 392 |
+
mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
|
| 393 |
+
return mask.view(1, 1, query_length, key_length)
|
| 394 |
+
|
| 395 |
+
def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 396 |
+
"""Standard attention computation with masking."""
|
| 397 |
+
batch_size, _, query_length, _ = query.size()
|
| 398 |
+
_, _, key_length, _ = key.size()
|
| 399 |
+
|
| 400 |
+
# Use cached window mask
|
| 401 |
+
with torch.no_grad():
|
| 402 |
+
window_mask = self._get_window_mask(query_length, key_length, query.device)
|
| 403 |
+
if padding_mask is not None:
|
| 404 |
+
attention_mask = padding_mask & window_mask
|
| 405 |
+
else:
|
| 406 |
+
attention_mask = window_mask
|
| 407 |
+
|
| 408 |
+
attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, Q_T, K_T]
|
| 409 |
+
attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
|
| 410 |
+
|
| 411 |
+
attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
|
| 412 |
+
attention_probabilities = self.attention_dropout(attention_probabilities)
|
| 413 |
+
|
| 414 |
+
output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
|
| 415 |
+
output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
|
| 416 |
+
|
| 417 |
+
return output
|
| 418 |
+
|
| 419 |
+
def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
|
| 420 |
+
# Get original shape info
|
| 421 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 422 |
+
# Unpadded case
|
| 423 |
+
indices, cu_seqlens, max_seqlen = padding_info
|
| 424 |
+
total_seqlen = hidden_layer.size(0)
|
| 425 |
+
batch_size = cu_seqlens.size(0) - 1
|
| 426 |
+
else:
|
| 427 |
+
# Padded case
|
| 428 |
+
batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
|
| 429 |
+
|
| 430 |
+
hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
|
| 431 |
+
qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
|
| 432 |
+
|
| 433 |
+
query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
|
| 434 |
+
value = self.v_proj(hidden_layer)
|
| 435 |
+
|
| 436 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 437 |
+
# Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
|
| 438 |
+
query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
|
| 439 |
+
key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
|
| 440 |
+
value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
|
| 441 |
+
|
| 442 |
+
# Apply layer norm and scaling
|
| 443 |
+
query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
|
| 444 |
+
key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
|
| 445 |
+
|
| 446 |
+
# if v1 is None:
|
| 447 |
+
# v1 = value
|
| 448 |
+
# value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
|
| 449 |
+
|
| 450 |
+
# Prepare qkv for FlashAttention
|
| 451 |
+
qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
|
| 452 |
+
|
| 453 |
+
# Determine window size for local attention
|
| 454 |
+
if self.window_length is not None and self.window_length > 0:
|
| 455 |
+
if self.is_causal:
|
| 456 |
+
local_attention = (self.window_length - 1, 0)
|
| 457 |
+
else:
|
| 458 |
+
local_attention = (self.window_length - 1, self.window_length - 1)
|
| 459 |
+
else:
|
| 460 |
+
local_attention = (-1, -1)
|
| 461 |
+
|
| 462 |
+
# Apply FlashAttention
|
| 463 |
+
output = flash_attention_forward(
|
| 464 |
+
qkv,
|
| 465 |
+
self.rope_embedding,
|
| 466 |
+
cu_seqlens,
|
| 467 |
+
max_seqlen,
|
| 468 |
+
self.is_causal,
|
| 469 |
+
local_attention,
|
| 470 |
+
self.config.attention_dropout if self.training else 0.0,
|
| 471 |
+
self.config.deterministic_flash_attn
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Reshape output back
|
| 475 |
+
output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
|
| 476 |
+
|
| 477 |
+
else:
|
| 478 |
+
# Standard attention path
|
| 479 |
+
query_length = query.size(1)
|
| 480 |
+
key_length = key.size(1)
|
| 481 |
+
|
| 482 |
+
query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
|
| 483 |
+
key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
|
| 484 |
+
value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
|
| 485 |
+
|
| 486 |
+
query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
|
| 487 |
+
key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
|
| 488 |
+
|
| 489 |
+
# if v1 is None:
|
| 490 |
+
# v1 = value
|
| 491 |
+
# else:
|
| 492 |
+
# value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
|
| 493 |
+
|
| 494 |
+
# Apply rotary embeddings
|
| 495 |
+
query = self.rope_embedding(query)
|
| 496 |
+
key = self.rope_embedding(key)
|
| 497 |
+
|
| 498 |
+
output = self.attention_operation(query, key, value, padding_info)
|
| 499 |
+
output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T, H*D]
|
| 500 |
+
|
| 501 |
+
output = self.inter_norm(output.float()).type_as(output)
|
| 502 |
+
output = self.out_proj(output)
|
| 503 |
+
output = self.dropout(output)
|
| 504 |
+
|
| 505 |
+
return output, v1
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
class FeedForward(nn.Module):
|
| 509 |
+
def __init__(self, config: GptBertConfig):
|
| 510 |
+
super().__init__()
|
| 511 |
+
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 512 |
+
self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
|
| 513 |
+
self.activation = GeGLU()
|
| 514 |
+
self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 515 |
+
self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
|
| 516 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 517 |
+
|
| 518 |
+
def forward(self, x: torch.Tensor):
|
| 519 |
+
x = self.pre_norm(x.float()).type_as(x)
|
| 520 |
+
x = self.up_proj(x)
|
| 521 |
+
x = self.activation(x)
|
| 522 |
+
x = self.inter_norm(x.float()).type_as(x)
|
| 523 |
+
x = self.down_proj(x)
|
| 524 |
+
x = self.dropout(x)
|
| 525 |
+
return x
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class Layer(nn.Module):
|
| 529 |
+
|
| 530 |
+
def __init__(self, config: ModelConfig, layer_idx: int) -> None:
|
| 531 |
+
super().__init__()
|
| 532 |
+
|
| 533 |
+
self.attention: SelfAttention
|
| 534 |
+
self.mlp: FeedForward
|
| 535 |
+
|
| 536 |
+
self.attention = SelfAttention(config, layer_idx)
|
| 537 |
+
self.mlp = FeedForward(config)
|
| 538 |
+
self.lambdas_v = nn.Parameter(torch.tensor([1.0, 0.0]))
|
| 539 |
+
self.lambdas_qk = nn.Parameter(torch.tensor([1.0, 0.0]))
|
| 540 |
+
self.lambdas_mlp = nn.Parameter(torch.tensor([1.0, 1.0, 0.0]))
|
| 541 |
+
self.lambdas_out = nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 0.0]))
|
| 542 |
+
|
| 543 |
+
def set_window_length(self, window_length: int) -> None:
|
| 544 |
+
self.attention.set_window_length(window_length)
|
| 545 |
+
|
| 546 |
+
def normalize_lambda(self, lambdas: torch.Tensor) -> torch.Tensor:
|
| 547 |
+
lambdas = lambdas / (lambdas.abs().mean() + 1e-6)
|
| 548 |
+
return lambdas
|
| 549 |
+
|
| 550 |
+
def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
|
| 551 |
+
output: torch.Tensor
|
| 552 |
+
|
| 553 |
+
lambdas_v = self.normalize_lambda(self.lambdas_v)
|
| 554 |
+
lambdas_qk = self.normalize_lambda(self.lambdas_qk)
|
| 555 |
+
lambdas_mlp = self.normalize_lambda(self.lambdas_mlp)
|
| 556 |
+
lambdas_out = self.normalize_lambda(self.lambdas_out)
|
| 557 |
+
|
| 558 |
+
v_layer = (lambdas_v[0] * hidden_layer) + (lambdas_v[1] * embeddings)
|
| 559 |
+
qk_layer = (lambdas_qk[0] * hidden_layer) + (lambdas_qk[1] * embeddings)
|
| 560 |
+
attention_output, v1 = self.attention(v_layer, qk_layer, v1, padding_info)
|
| 561 |
+
|
| 562 |
+
mlp_layer = (lambdas_mlp[0] * attention_output) + (lambdas_mlp[1] * hidden_layer) + (lambdas_mlp[2] * embeddings)
|
| 563 |
+
mlp_layer = self.mlp(mlp_layer)
|
| 564 |
+
|
| 565 |
+
output = (lambdas_out[0] * mlp_layer) + (lambdas_out[1] * attention_output) + (lambdas_out[2] * hidden_layer) + (lambdas_out[3] * embeddings)
|
| 566 |
+
|
| 567 |
+
return output, v1
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class Encoder(nn.Module):
|
| 571 |
+
def __init__(self, config: GptBertConfig):
|
| 572 |
+
super().__init__()
|
| 573 |
+
self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
|
| 574 |
+
self.local_global_ratio = config.local_global_ratio
|
| 575 |
+
|
| 576 |
+
def set_window_length(self, config: GptBertConfig):
|
| 577 |
+
for i, layer in enumerate(self.layers):
|
| 578 |
+
if (i + 1) % self.local_global_ratio == 0:
|
| 579 |
+
layer.set_window_length(config.global_window_length)
|
| 580 |
+
else:
|
| 581 |
+
layer.set_window_length(config.local_window_length)
|
| 582 |
+
|
| 583 |
+
def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
|
| 584 |
+
hidden_layers = [hidden_layer] if output_hidden_states else None
|
| 585 |
+
v1 = None
|
| 586 |
+
embeddings = hidden_layer
|
| 587 |
+
|
| 588 |
+
for layer in self.layers:
|
| 589 |
+
if checkpoint_activations:
|
| 590 |
+
hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
|
| 591 |
+
else:
|
| 592 |
+
hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
|
| 593 |
+
|
| 594 |
+
if output_hidden_states:
|
| 595 |
+
hidden_layers.append(hidden_layer)
|
| 596 |
+
|
| 597 |
+
return hidden_layer, hidden_layers
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
#
|
| 601 |
+
# HuggingFace wrappers
|
| 602 |
+
#
|
| 603 |
+
|
| 604 |
+
class GptBertPreTrainedModel(PreTrainedModel):
|
| 605 |
+
config_class = GptBertConfig
|
| 606 |
+
supports_gradient_checkpointing = True
|
| 607 |
+
_supports_flash_attn_2 = True
|
| 608 |
+
_supports_sdpa = True
|
| 609 |
+
_supports_flex_attn = False
|
| 610 |
+
|
| 611 |
+
def _init_weights(self, module):
|
| 612 |
+
std = math.sqrt(2.0 / (5.0 * self.hidden_size))
|
| 613 |
+
|
| 614 |
+
if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
|
| 615 |
+
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 616 |
+
if module.bias is not None:
|
| 617 |
+
module.bias.data.zero_()
|
| 618 |
+
elif isinstance(module, nn.Embedding):
|
| 619 |
+
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 620 |
+
elif isinstance(module, nn.LayerNorm):
|
| 621 |
+
module.bias.data.zero_()
|
| 622 |
+
module.weight.data.fill_(1.0)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class GptBertModel(GptBertPreTrainedModel):
|
| 626 |
+
def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
|
| 627 |
+
super().__init__(config, **kwargs)
|
| 628 |
+
self.config = config
|
| 629 |
+
self.hidden_size = config.hidden_size
|
| 630 |
+
|
| 631 |
+
self.embedding = Embedding(config)
|
| 632 |
+
self.encoder = Encoder(config)
|
| 633 |
+
self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
|
| 634 |
+
self.set_window_length(config)
|
| 635 |
+
self.gradient_checkpointing = False
|
| 636 |
+
self.post_init()
|
| 637 |
+
|
| 638 |
+
def set_window_length(self, config) -> None:
|
| 639 |
+
self.encoder.set_window_length(config)
|
| 640 |
+
|
| 641 |
+
def get_input_embeddings(self):
|
| 642 |
+
return self.embedding.word_embedding
|
| 643 |
+
|
| 644 |
+
def set_input_embeddings(self, value):
|
| 645 |
+
self.embedding.word_embedding = value
|
| 646 |
+
|
| 647 |
+
def get_contextualized_embeddings(
|
| 648 |
+
self,
|
| 649 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 650 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 651 |
+
output_hidden_states: Optional[bool] = None
|
| 652 |
+
):
|
| 653 |
+
if input_ids is not None:
|
| 654 |
+
input_shape = input_ids.size()
|
| 655 |
+
else:
|
| 656 |
+
raise ValueError("You have to specify input_ids")
|
| 657 |
+
|
| 658 |
+
batch_size, seq_length = input_shape
|
| 659 |
+
device = input_ids.device
|
| 660 |
+
|
| 661 |
+
if attention_mask is None:
|
| 662 |
+
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
|
| 663 |
+
else:
|
| 664 |
+
attention_mask = attention_mask.bool()
|
| 665 |
+
|
| 666 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 667 |
+
if len(attention_mask.size()) != 2:
|
| 668 |
+
raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
|
| 669 |
+
with torch.no_grad():
|
| 670 |
+
input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
|
| 671 |
+
padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
|
| 672 |
+
else:
|
| 673 |
+
if len(attention_mask.size()) == 2:
|
| 674 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 675 |
+
elif len(attention_mask.size()) == 3:
|
| 676 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 677 |
+
padding_info = attention_mask
|
| 678 |
+
|
| 679 |
+
static_embeddings = self.embedding(input_ids)
|
| 680 |
+
|
| 681 |
+
original_dtype = static_embeddings.dtype
|
| 682 |
+
if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
|
| 683 |
+
static_embeddings = static_embeddings.bfloat16()
|
| 684 |
+
|
| 685 |
+
last_layer, contextualized_embeddings = self.encoder(
|
| 686 |
+
static_embeddings,
|
| 687 |
+
padding_info,
|
| 688 |
+
output_hidden_states=output_hidden_states,
|
| 689 |
+
checkpoint_activations=self.gradient_checkpointing and self.training
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
last_layer = last_layer.to(original_dtype)
|
| 693 |
+
if output_hidden_states:
|
| 694 |
+
contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
|
| 695 |
+
|
| 696 |
+
# Pad output if using FlashAttention
|
| 697 |
+
if flash_attn_varlen_qkvpacked_func is not None:
|
| 698 |
+
last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
|
| 699 |
+
if output_hidden_states:
|
| 700 |
+
contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
|
| 701 |
+
else:
|
| 702 |
+
contextualized_embeddings = None
|
| 703 |
+
|
| 704 |
+
return last_layer, contextualized_embeddings
|
| 705 |
+
|
| 706 |
+
def forward(
|
| 707 |
+
self,
|
| 708 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 709 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 710 |
+
output_hidden_states: Optional[bool] = None,
|
| 711 |
+
output_attentions: Optional[bool] = None,
|
| 712 |
+
return_dict: Optional[bool] = None,
|
| 713 |
+
**kwargs
|
| 714 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
|
| 715 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 716 |
+
|
| 717 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 718 |
+
|
| 719 |
+
if not return_dict:
|
| 720 |
+
return (
|
| 721 |
+
sequence_output,
|
| 722 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
return BaseModelOutput(
|
| 726 |
+
last_hidden_state=sequence_output,
|
| 727 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
class GptBertForMaskedLM(GptBertModel):
|
| 732 |
+
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 733 |
+
|
| 734 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 735 |
+
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 736 |
+
|
| 737 |
+
def get_output_embeddings(self):
|
| 738 |
+
return self.classifier.emb2vocab.weight
|
| 739 |
+
|
| 740 |
+
def set_output_embeddings(self, new_embeddings):
|
| 741 |
+
self.classifier.emb2vocab.weight = new_embeddings
|
| 742 |
+
|
| 743 |
+
def forward(
|
| 744 |
+
self,
|
| 745 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 746 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 747 |
+
output_hidden_states: Optional[bool] = None,
|
| 748 |
+
return_dict: Optional[bool] = None,
|
| 749 |
+
labels: Optional[torch.LongTensor] = None,
|
| 750 |
+
**kwargs
|
| 751 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 752 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 753 |
+
|
| 754 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 755 |
+
subword_prediction = self.classifier(sequence_output)
|
| 756 |
+
subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
|
| 757 |
+
|
| 758 |
+
masked_lm_loss = None
|
| 759 |
+
if labels is not None:
|
| 760 |
+
labels_flatten = labels[:, 1:].flatten()
|
| 761 |
+
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 762 |
+
masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 763 |
+
|
| 764 |
+
bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
|
| 765 |
+
bos_logits[:, :, self.config.bos_token_id] = 1.0
|
| 766 |
+
subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
|
| 767 |
+
|
| 768 |
+
if not return_dict:
|
| 769 |
+
output = (
|
| 770 |
+
subword_prediction,
|
| 771 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 772 |
+
)
|
| 773 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 774 |
+
|
| 775 |
+
return MaskedLMOutput(
|
| 776 |
+
loss=masked_lm_loss,
|
| 777 |
+
logits=subword_prediction,
|
| 778 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
class GptBertForCausalLM(GptBertModel):
|
| 783 |
+
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 784 |
+
|
| 785 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 786 |
+
config.is_decoder = True
|
| 787 |
+
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 788 |
+
|
| 789 |
+
def get_output_embeddings(self):
|
| 790 |
+
return self.classifier.emb2vocab.weight
|
| 791 |
+
|
| 792 |
+
def set_output_embeddings(self, new_embeddings):
|
| 793 |
+
self.classifier.emb2vocab.weight = new_embeddings
|
| 794 |
+
|
| 795 |
+
def get_input_embeddings(self):
|
| 796 |
+
return self.embedding.word_embedding
|
| 797 |
+
|
| 798 |
+
def set_input_embeddings(self, value):
|
| 799 |
+
self.embedding.word_embedding = value
|
| 800 |
+
|
| 801 |
+
def set_decoder(self, decoder):
|
| 802 |
+
self.encoder = decoder
|
| 803 |
+
|
| 804 |
+
def get_decoder(self):
|
| 805 |
+
return self.encoder
|
| 806 |
+
|
| 807 |
+
def can_generate(self):
|
| 808 |
+
return True
|
| 809 |
+
|
| 810 |
+
def forward(
|
| 811 |
+
self,
|
| 812 |
+
input_ids: torch.LongTensor = None,
|
| 813 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 814 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 815 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 816 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 817 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 818 |
+
labels: Optional[torch.LongTensor] = None,
|
| 819 |
+
use_cache: Optional[bool] = None,
|
| 820 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 821 |
+
output_attentions: Optional[bool] = None,
|
| 822 |
+
output_hidden_states: Optional[bool] = None,
|
| 823 |
+
return_dict: Optional[bool] = None
|
| 824 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 825 |
+
|
| 826 |
+
assert inputs_embeds is None, "inputs_embeds is not supported for now"
|
| 827 |
+
assert past_key_values is None, "past_key_values is not supported for now"
|
| 828 |
+
assert not use_cache, "use_cache is not supported for now"
|
| 829 |
+
|
| 830 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 831 |
+
subword_prediction = self.classifier(sequence_output)
|
| 832 |
+
subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
|
| 833 |
+
|
| 834 |
+
causal_lm_loss = None
|
| 835 |
+
if labels is not None:
|
| 836 |
+
labels_flatten = labels[:, 1:].flatten()
|
| 837 |
+
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
|
| 838 |
+
causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
|
| 839 |
+
|
| 840 |
+
if not return_dict:
|
| 841 |
+
output = (
|
| 842 |
+
subword_prediction,
|
| 843 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 844 |
+
)
|
| 845 |
+
return ((causal_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 846 |
+
|
| 847 |
+
return CausalLMOutput(
|
| 848 |
+
loss=causal_lm_loss,
|
| 849 |
+
logits=subword_prediction,
|
| 850 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
def prepare_inputs_for_generation(
|
| 854 |
+
self,
|
| 855 |
+
input_ids: torch.Tensor,
|
| 856 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 857 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 858 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 859 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 860 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 861 |
+
use_cache: bool = True,
|
| 862 |
+
num_logits_to_keep: Optional[int] = None,
|
| 863 |
+
**kwargs,
|
| 864 |
+
):
|
| 865 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
| 866 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
| 867 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
| 868 |
+
if past_key_values is not None:
|
| 869 |
+
if inputs_embeds is not None: # Exception 1
|
| 870 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
| 871 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
| 872 |
+
input_ids = input_ids[:, cache_position]
|
| 873 |
+
|
| 874 |
+
if attention_mask is not None and position_ids is None:
|
| 875 |
+
# create position_ids on the fly for batch generation
|
| 876 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 877 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 878 |
+
if past_key_values:
|
| 879 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 880 |
+
|
| 881 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
| 882 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
| 883 |
+
|
| 884 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 885 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
| 886 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 887 |
+
else:
|
| 888 |
+
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
| 889 |
+
|
| 890 |
+
if num_logits_to_keep is not None:
|
| 891 |
+
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
| 892 |
+
|
| 893 |
+
model_inputs.update(
|
| 894 |
+
{
|
| 895 |
+
"position_ids": position_ids,
|
| 896 |
+
"cache_position": cache_position,
|
| 897 |
+
"past_key_values": past_key_values,
|
| 898 |
+
"use_cache": use_cache,
|
| 899 |
+
"attention_mask": attention_mask,
|
| 900 |
+
}
|
| 901 |
+
)
|
| 902 |
+
return model_inputs
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class GptBertForSequenceClassification(GptBertModel):
|
| 906 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 907 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 908 |
+
|
| 909 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 910 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 911 |
+
|
| 912 |
+
self.num_labels = config.num_labels
|
| 913 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 914 |
+
self.post_init()
|
| 915 |
+
|
| 916 |
+
def forward(
|
| 917 |
+
self,
|
| 918 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 919 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 920 |
+
output_hidden_states: Optional[bool] = None,
|
| 921 |
+
return_dict: Optional[bool] = None,
|
| 922 |
+
labels: Optional[torch.LongTensor] = None,
|
| 923 |
+
**kwargs
|
| 924 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 925 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 926 |
+
|
| 927 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 928 |
+
logits = self.classifier(sequence_output[:, 0, :])
|
| 929 |
+
|
| 930 |
+
loss = None
|
| 931 |
+
if labels is not None:
|
| 932 |
+
if self.config.problem_type is None:
|
| 933 |
+
if self.num_labels == 1:
|
| 934 |
+
self.config.problem_type = "regression"
|
| 935 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 936 |
+
self.config.problem_type = "single_label_classification"
|
| 937 |
+
else:
|
| 938 |
+
self.config.problem_type = "multi_label_classification"
|
| 939 |
+
|
| 940 |
+
if self.config.problem_type == "regression":
|
| 941 |
+
loss_fct = nn.MSELoss()
|
| 942 |
+
if self.num_labels == 1:
|
| 943 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 944 |
+
else:
|
| 945 |
+
loss = loss_fct(logits, labels)
|
| 946 |
+
elif self.config.problem_type == "single_label_classification":
|
| 947 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 948 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 949 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 950 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
| 951 |
+
loss = loss_fct(logits, labels)
|
| 952 |
+
|
| 953 |
+
if not return_dict:
|
| 954 |
+
output = (
|
| 955 |
+
logits,
|
| 956 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 957 |
+
)
|
| 958 |
+
return ((loss,) + output) if loss is not None else output
|
| 959 |
+
|
| 960 |
+
return SequenceClassifierOutput(
|
| 961 |
+
loss=loss,
|
| 962 |
+
logits=logits,
|
| 963 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
class GptBertForTokenClassification(GptBertModel):
|
| 968 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 969 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 970 |
+
|
| 971 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 972 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 973 |
+
|
| 974 |
+
self.num_labels = config.num_labels
|
| 975 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 976 |
+
self.post_init()
|
| 977 |
+
|
| 978 |
+
def forward(
|
| 979 |
+
self,
|
| 980 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 981 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 982 |
+
output_hidden_states: Optional[bool] = None,
|
| 983 |
+
return_dict: Optional[bool] = None,
|
| 984 |
+
labels: Optional[torch.LongTensor] = None,
|
| 985 |
+
**kwargs
|
| 986 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
| 987 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 988 |
+
|
| 989 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 990 |
+
logits = self.classifier(sequence_output)
|
| 991 |
+
|
| 992 |
+
loss = None
|
| 993 |
+
if labels is not None:
|
| 994 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 995 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 996 |
+
|
| 997 |
+
if not return_dict:
|
| 998 |
+
output = (
|
| 999 |
+
logits,
|
| 1000 |
+
*([contextualized_embeddings] if output_hidden_states else []),
|
| 1001 |
+
*([attention_probs] if output_attentions else [])
|
| 1002 |
+
)
|
| 1003 |
+
return ((loss,) + output) if loss is not None else output
|
| 1004 |
+
|
| 1005 |
+
return TokenClassifierOutput(
|
| 1006 |
+
loss=loss,
|
| 1007 |
+
logits=logits,
|
| 1008 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None,
|
| 1009 |
+
attentions=attention_probs if output_attentions else None
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
class GptBertForQuestionAnswering(GptBertModel):
|
| 1014 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1015 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1016 |
+
|
| 1017 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 1018 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 1019 |
+
|
| 1020 |
+
self.num_labels = config.num_labels
|
| 1021 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 1022 |
+
self.post_init()
|
| 1023 |
+
|
| 1024 |
+
def forward(
|
| 1025 |
+
self,
|
| 1026 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1027 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1028 |
+
output_hidden_states: Optional[bool] = None,
|
| 1029 |
+
return_dict: Optional[bool] = None,
|
| 1030 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 1031 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 1032 |
+
**kwargs
|
| 1033 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
| 1034 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1035 |
+
|
| 1036 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
|
| 1037 |
+
logits = self.classifier(sequence_output)
|
| 1038 |
+
|
| 1039 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1040 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1041 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1042 |
+
|
| 1043 |
+
total_loss = None
|
| 1044 |
+
if start_positions is not None and end_positions is not None:
|
| 1045 |
+
# If we are on multi-GPU, split add a dimension
|
| 1046 |
+
if len(start_positions.size()) > 1:
|
| 1047 |
+
start_positions = start_positions.squeeze(-1)
|
| 1048 |
+
if len(end_positions.size()) > 1:
|
| 1049 |
+
end_positions = end_positions.squeeze(-1)
|
| 1050 |
+
|
| 1051 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1052 |
+
ignored_index = start_logits.size(1)
|
| 1053 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1054 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1055 |
+
|
| 1056 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
|
| 1057 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1058 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1059 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1060 |
+
|
| 1061 |
+
if not return_dict:
|
| 1062 |
+
output = (
|
| 1063 |
+
start_logits,
|
| 1064 |
+
end_logits,
|
| 1065 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1066 |
+
)
|
| 1067 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1068 |
+
|
| 1069 |
+
return QuestionAnsweringModelOutput(
|
| 1070 |
+
loss=total_loss,
|
| 1071 |
+
start_logits=start_logits,
|
| 1072 |
+
end_logits=end_logits,
|
| 1073 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
class GptBertForMultipleChoice(GptBertModel):
|
| 1078 |
+
_keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1079 |
+
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1080 |
+
|
| 1081 |
+
def __init__(self, config: GptBertConfig, **kwargs):
|
| 1082 |
+
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 1083 |
+
|
| 1084 |
+
self.num_labels = getattr(config, "num_labels", 2)
|
| 1085 |
+
self.classifier = Classifier(config, self.num_labels)
|
| 1086 |
+
self.post_init()
|
| 1087 |
+
|
| 1088 |
+
def forward(
|
| 1089 |
+
self,
|
| 1090 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1091 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1092 |
+
labels: Optional[torch.Tensor] = None,
|
| 1093 |
+
output_hidden_states: Optional[bool] = None,
|
| 1094 |
+
return_dict: Optional[bool] = None,
|
| 1095 |
+
**kwargs
|
| 1096 |
+
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
| 1097 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1098 |
+
num_choices = input_ids.shape[1]
|
| 1099 |
+
|
| 1100 |
+
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
| 1101 |
+
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 1102 |
+
|
| 1103 |
+
sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
|
| 1104 |
+
logits = self.classifier(sequence_output)
|
| 1105 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 1106 |
+
|
| 1107 |
+
loss = None
|
| 1108 |
+
if labels is not None:
|
| 1109 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 1110 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 1111 |
+
|
| 1112 |
+
if not return_dict:
|
| 1113 |
+
output = (
|
| 1114 |
+
reshaped_logits,
|
| 1115 |
+
*([contextualized_embeddings] if output_hidden_states else [])
|
| 1116 |
+
)
|
| 1117 |
+
return ((loss,) + output) if loss is not None else output
|
| 1118 |
+
|
| 1119 |
+
return MultipleChoiceModelOutput(
|
| 1120 |
+
loss=loss,
|
| 1121 |
+
logits=reshaped_logits,
|
| 1122 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None
|
| 1123 |
+
)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:25dd6b44c778dc78c07a249d05ff982dd6c1d395c625fb09ea42f1e855610571
|
| 3 |
+
size 597586413
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"bos_token": "[CLS]", "eos_token": "[SEP]", "unk_token": "[UNK]", "sep_token": "[CLS]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 3 |
+
"bos_token": "[CLS]",
|
| 4 |
+
"eos_token": "[SEP]",
|
| 5 |
+
"unk_token": "[UNK]",
|
| 6 |
+
"sep_token": "[SEP]",
|
| 7 |
+
"pad_token": "[PAD]",
|
| 8 |
+
"cls_token": "[CLS]",
|
| 9 |
+
"mask_token": "[MASK]"
|
| 10 |
+
}
|