MariaFjodorowa committed
Commit 72ee8be · verified · 1 Parent(s): 2f4f0e2

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,161 @@
+ ---
+ language:
+ - ka
+ - kat
+ inference: false
+ tags:
+ - BERT
+ - HPLT
+ - encoder
+ - text2text-generation
+ license: apache-2.0
+ datasets:
+ - HPLT/HPLT3.0
+ ---
+
+ # HPLT v3.0 GPT-BERT for Georgian
+
+ <img src="https://hplt-project.org/_next/static/media/logo-hplt.d5e16ca5.svg" width=12.5%>
+
+ This is one of the monolingual language models trained for the third release of the [HPLT project](https://hplt-project.org/).
+ Our models follow the setup of [GPT-BERT](https://aclanthology.org/2024.conll-babylm.24/).
+
+ All the HPLT GPT-BERT models use the same hyper-parameters:
+ - hidden size: 640
+ - attention heads: 10
+ - layers: 24
+ - vocabulary size: 32768
+
+ Every model uses its own tokenizer trained on language-specific HPLT data.
+
+ The training code is available [on GitHub](https://github.com/ltgoslo/NorBERT/tree/main/norbert4).
+
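As a rough, purely illustrative sanity check (not an official figure from the release), the shared hyper-parameters above, together with the intermediate size of 1664 from the accompanying `config.json`, imply a weight-matrix parameter count of about 137M, ignoring norms, scales, biases, and the LM head:

```python
# Back-of-the-envelope parameter estimate from the shared hyper-parameters.
# Illustrative only: it ignores norms, scales, biases and the LM head,
# and takes intermediate_size = 1664 from the accompanying config.json.
hidden = 640
layers = 24
vocab = 32768
intermediate = 1664

embedding = vocab * hidden                 # word embedding matrix
per_layer_attn = 4 * hidden * hidden       # Q, K, V and output projections
per_layer_ffn = 3 * hidden * intermediate  # gated (GeGLU) up + down projections
total = embedding + layers * (per_layer_attn + per_layer_ffn)
print(f"~{total / 1e6:.0f}M parameters")  # ~137M
```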
+ ## Example usage (bidirectional encoding)
+
+ This model currently needs a custom wrapper from `modeling_gptbert.py`, so you should load it with `trust_remote_code=True`.
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ # Load the tokenizer and the model
+ tokenizer = AutoTokenizer.from_pretrained(
+     "HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
+ )
+ model = AutoModelForMaskedLM.from_pretrained(
+     "HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
+     trust_remote_code=True,
+     use_safetensors=False,
+ )
+ model = model.eval()
+ input_text = f"Norwegian is a {tokenizer.mask_token} Germanic language"
+ print(input_text)
+ # Tokenize the text (with a mask token inside)
+ input_text = tokenizer(
+     input_text,
+     return_tensors="pt",
+ )
+ # Inference
+ with torch.no_grad():
+     output_p = model(**input_text)
+
+ # Unmask the text
+ output_text = torch.where(
+     input_text.input_ids == tokenizer.mask_token_id,
+     output_p.logits.argmax(-1),
+     input_text.input_ids,
+ )
+
+ # Decoding; should output: 'Norwegian is a North Germanic language'
+ print(tokenizer.decode(output_text[0].tolist()))
+ ```
+
+ ## Example usage (text generation)
+
+ GPT-BERT also supports unidirectional decoding, so it can generate text like any other GPT model:
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     "HPLT/hplt_gpt_bert_base_3_0_kat_Geor",
+     trust_remote_code=True,
+     use_safetensors=False,
+ )
+ text = "The Norwegian Constitution"
+ print(text, flush=True)
+ # Define the tokens that should end the generation
+ eos_token_ids = [
+     token_id
+     for token_id in range(tokenizer.vocab_size)
+     if '.' in tokenizer.decode([token_id])
+ ]
+
+ # Generation function
+ @torch.no_grad()
+ def generate(text):
+     input_ids = tokenizer(text, return_tensors='pt').input_ids
+     prediction = model.generate(
+         input_ids,
+         max_new_tokens=63,
+         do_sample=False,
+         eos_token_id=eos_token_ids,
+     )
+     return tokenizer.decode(prediction[0]).strip()
+
+ # Example usage; should output: '[CLS]The Norwegian Constitution[SEP]is a document that defines the rights and responsibilities of the Norwegian people and their representatives.'
+ print(generate(text), flush=True)
+ ```
+
+ The following classes are currently implemented: `AutoModel`, `AutoModelForMaskedLM`, `AutoModelForCausalLM`, `AutoModelForSequenceClassification`, `AutoModelForTokenClassification`, `AutoModelForQuestionAnswering` and `AutoModelForMultipleChoice`.
+
+ ## Intermediate checkpoints
+
+ We are releasing 10 intermediate checkpoints for each model, one every 3125 training steps, in separate branches. The naming convention is `stepXXX`; for example, `step18750`.
+
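With one checkpoint every 3125 steps, the ten branch names implied by this convention can be enumerated programmatically (a small sketch based on the naming scheme above, not on an actual listing of the repository):

```python
# Expected intermediate-checkpoint branch names: one every 3125 steps,
# ten in total, assuming they run from step3125 up to step31250.
revisions = [f"step{3125 * i}" for i in range(1, 11)]
print(revisions)
```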
+ You can load a specific model revision with `transformers` using the `revision` argument:
+ ```python
+ model = AutoModelForMaskedLM.from_pretrained("HPLT/hplt_gpt_bert_base_3_0_kat_Geor", revision="step21875", trust_remote_code=True)
+ ```
+
+ You can access all the revisions for the models with the following code:
+ ```python
+ from huggingface_hub import list_repo_refs
+ out = list_repo_refs("HPLT/hplt_gpt_bert_base_3_0_kat_Geor")
+ print([b.name for b in out.branches])
+ ```
+
+ ## Cite us
+
+ ```bibtex
+ @inproceedings{charpentier-samuel-2024-bert,
+     title = "{GPT} or {BERT}: why not both?",
+     author = "Charpentier, Lucas Georges Gabriel and
+       Samuel, David",
+     booktitle = "The 2nd BabyLM Challenge at the 28th Conference on Computational Natural Language Learning",
+     month = nov,
+     year = "2024",
+     address = "Miami, FL, USA",
+     publisher = "Association for Computational Linguistics",
+     url = "https://aclanthology.org/2024.conll-babylm.24/",
+     pages = "262--283"
+ }
+ ```
+
+ ```bibtex
+ @misc{oepen2025hplt30largescalemultilingual,
+     title = {{HPLT 3.0}: {V}ery Large-Scale Multilingual Resources for {LLM} and {MT}. Mono- and Bi-lingual Data, Multilingual Evaluation, and Pre-Trained Models},
+     author = {Stephan Oepen and Nikolay Arefev and Mikko Aulamo and Marta Bañón and Maja Buljan and Laurie Burchell and Lucas Charpentier and Pinzhen Chen and Mariia Fedorova and Ona de Gibert and Barry Haddow and Jan Hajič and Jindřich Helcl and Andrey Kutuzov and Veronika Laippala and Zihao Li and Risto Luukkonen and Bhavitvya Malik and Vladislav Mikhailov and Amanda Myntti and Dayyán O'Brien and Lucie Poláková and Sampo Pyysalo and Gema Ramírez Sánchez and Janine Siewert and Pavel Stepachev and Jörg Tiedemann and Teemu Vahtola and Dušan Variš and Fedor Vitiugin and Tea Vojtěchová and Jaume Zaragoza},
+     year = {2025},
+     eprint = {2511.01066},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.CL},
+     url = {https://arxiv.org/abs/2511.01066},
+ }
+ ```
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2410.24159-b31b1b.svg)](https://arxiv.org/abs/2410.24159)
+ [![arXiv](https://img.shields.io/badge/arXiv-2511.01066-b31b1b.svg)](https://arxiv.org/abs/2511.01066)
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "architectures": [
+     "GptBertForMaskedLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_gptbert.GptBertConfig",
+     "AutoModel": "modeling_gptbert.GptBertModel",
+     "AutoModelForCausalLM": "modeling_gptbert.GptBertForCausalLM",
+     "AutoModelForMaskedLM": "modeling_gptbert.GptBertForMaskedLM",
+     "AutoModelForSequenceClassification": "modeling_gptbert.GptBertForSequenceClassification",
+     "AutoModelForTokenClassification": "modeling_gptbert.GptBertForTokenClassification",
+     "AutoModelForQuestionAnswering": "modeling_gptbert.GptBertForQuestionAnswering",
+     "AutoModelForMultipleChoice": "modeling_gptbert.GptBertForMultipleChoice"
+   },
+   "unk_token_id": 1,
+   "bos_token_id": 2,
+   "eos_token_id": 3,
+   "pad_token_id": 0,
+   "mask_token_id": 4,
+   "hidden_size": 640,
+   "intermediate_size": 1664,
+   "max_sequence_length": 16384,
+   "num_layers": 24,
+   "attention_dropout": 0.0,
+   "hidden_dropout": 0.0,
+   "embedding_dropout": 0.1,
+   "classifier_dropout": 0.2,
+   "layer_norm_eps": 1e-07,
+   "query_key_head_size": 64,
+   "value_head_size": 64,
+   "num_attention_heads": 10,
+   "rope_theta": 160000,
+   "vocab_size": 32768,
+   "local_global_ratio": 4,
+   "global_window_length": 8192,
+   "local_window_length": 256,
+   "deterministic_flash_attn": false,
+   "use_cache": false
+ }
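With `local_global_ratio` set to 4, every fourth layer attends globally (window 8192) while the remaining layers use the local window of 256. Which of the 24 layers are global can be sketched as follows, assuming the same `(layer_idx + 1) % local_global_ratio` indexing that `modeling_gptbert.py` uses when selecting the RoPE theta:

```python
# Layers whose index satisfies (layer_idx + 1) % local_global_ratio == 0
# attend globally; the others use the local sliding window.
num_layers = 24
local_global_ratio = 4
global_layers = [i for i in range(num_layers) if (i + 1) % local_global_ratio == 0]
print(global_layers)  # 0-indexed: [3, 7, 11, 15, 19, 23]
```

So 6 of the 24 layers are global and 18 are local.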
configuration_gptbert.py ADDED
@@ -0,0 +1,34 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class GptBertConfig(PretrainedConfig):
+
+     def __init__(
+         self,
+         config_file: Path | str | None = None,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.model = "norbert4"
+
+         if config_file is not None:
+             if type(config_file) is str:
+                 config_file = Path(config_file)
+             assert type(config_file) is Path, "config_file should be either a Path or a str"
+             with config_file.open("r") as file:
+                 config = json.load(file)
+
+             for attr, value in config.items():
+                 if isinstance(value, str):
+                     value = value.lower()
+                 setattr(self, attr, value)
+
+         for attr, value in kwargs.items():
+             if isinstance(value, str):
+                 value = value.lower()
+             setattr(self, attr, value)
modeling_gptbert.py ADDED
@@ -0,0 +1,1123 @@
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch import _softmax_backward_data as _softmax_backward_data
+
+ from functools import partial, lru_cache
+
+ from .configuration_gptbert import GptBertConfig
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.activations import gelu_new
+ from transformers.utils import is_flash_attn_2_available, logging
+ from transformers.modeling_outputs import (
+     MaskedLMOutput,
+     MultipleChoiceModelOutput,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutput,
+     TokenClassifierOutput,
+     BaseModelOutput,
+     CausalLMOutput
+ )
+ import math
+ from typing import TYPE_CHECKING, Optional, Union, Tuple, List
+
+ logger = logging.get_logger(__name__)
+
+
+ # Workaround for transformers < 4.36.0 check_imports issue
+ # See: https://github.com/huggingface/transformers/issues/28459
+ try:
+     if is_flash_attn_2_available():
+         from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+         from flash_attn.layers.rotary import RotaryEmbedding
+         from flash_attn.ops.triton.rotary import apply_rotary
+     else:
+         flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
+         logger.warning_once(
+             "NorBERT4 supports FlashAttention, but it was not found in your environment. Consider updating your environment for faster and more memory-efficient processing."
+         )
+ except ImportError:
+     flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
+     logger.warning_once(
+         "NorBERT4 supports FlashAttention, but it was not found in your environment. Consider updating your environment for faster and more memory-efficient processing."
+     )
+
+
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+ @torch.compiler.disable()
+ def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
+     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+     max_seqlen_in_batch = int(seqlens_in_batch.max().item())
+     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+     if input_ids.dim() == 2:
+         unpadded_inputs = input_ids.flatten()[indices]
+     else:
+         batch_size, sequence_length, *rest = input_ids.shape
+         shape = batch_size * sequence_length
+         unpadded_inputs = input_ids.view(shape, *rest)[indices]
+
+     return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
+
+
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+ def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
+     if input_ids.dim() == 1:
+         output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
+         output[indices] = input_ids
+         padded_inputs = output.view(batch_size, sequence_length)
+     else:
+         _, *rest = input_ids.shape
+         output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
+         output[indices] = input_ids
+         padded_inputs = output.view(batch_size, sequence_length, *rest)
+
+     return padded_inputs
+
+
+ class CastedLinear(nn.Linear):
+     def __init__(self, in_features, out_features, bias):
+         super().__init__(in_features, out_features, bias=bias)
+
+     def forward(self, x):
+         return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
+
+
+ class CastedLinearIn(nn.Linear):
+     def __init__(self, in_features, out_features, bias):
+         super().__init__(in_features, out_features, bias=bias)
+         self.scale = nn.Parameter(torch.ones(in_features))
+
+     def forward(self, x):
+         return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
+
+
+ class MultiCastedLinearOrthoIn(nn.Module):
+     def __init__(self, in_features, out_features, bias):
+         super().__init__()
+
+         self.in_features = in_features
+         self.out_features = out_features
+
+         self.weights = nn.ParameterList()
+         for out_feature in out_features:
+             self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
+
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(sum(out_features)))
+         else:
+             self.bias = self.register_parameter("bias", None)
+
+         self.scale = nn.Parameter(torch.ones(in_features))
+
+     def forward(self, x):
+         return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
+
+
+ class GeGLU(nn.Module):
+     def forward(self, x):
+         x, gate = x.chunk(2, dim=-1)
+         return x * gelu_new(gate)
+
+
+ class Embedding(nn.Module):
+     def __init__(self, config: GptBertConfig):
+         super().__init__()
+
+         self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
+         self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
+         self.dropout = nn.Dropout(config.embedding_dropout)
+
+     def forward(self, input_ids: torch.Tensor):
+         word_embedding = self.word_embedding(input_ids)
+         word_embedding = self.word_norm(word_embedding)
+         word_embedding = word_embedding * (self.word_scale + 1.0)
+
+         return self.dropout(word_embedding)
+
+
+ class LMClassifier(nn.Module):
+     def __init__(self, config: GptBertConfig, n_labels: int):
+         super().__init__()
+
+         self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+         self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
+         self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+         self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
+
+     def forward(self, x: torch.Tensor):
+         x = self.pre_norm(x.float()).type_as(x)
+         x = self.projection(x)
+         x = gelu_new(x)
+         x = self.post_norm(x.float()).type_as(x)
+         x = self.emb2vocab(x)
+         return x
+
+
+ class Classifier(nn.Module):
+     def __init__(self, config: GptBertConfig, n_labels: int):
+         super().__init__()
+
+         self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+         self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
+         self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+         self.dropout = nn.Dropout(config.classifier_dropout)
+         self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
+
+     def forward(self, x: torch.Tensor):
+         x = self.pre_norm(x.float()).type_as(x)
+         x = self.projection(x)
+         x = gelu_new(x)
+         x = self.post_norm(x.float()).type_as(x)
+         x = self.dropout(x)
+         x = self.output_projection(x)
+         return x
+
+
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+ def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
+     qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+
+     convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
+     if convert_dtype:
+         # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
+         # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
+         orig_dtype = qkv.dtype
+         qkv = qkv.to(target_dtype)
+
+         attn = flash_attn_varlen_qkvpacked_func(
+             qkv,
+             cu_seqlens=cu_seqlens,
+             max_seqlen=max_seqlen,
+             dropout_p=dropout_p,
+             deterministic=deterministic,
+             window_size=local_attention,
+             causal=causal,
+         )
+         attn = attn.to(orig_dtype)  # type: ignore
+     else:
+         attn = flash_attn_varlen_qkvpacked_func(
+             qkv,
+             cu_seqlens=cu_seqlens,
+             max_seqlen=max_seqlen,
+             dropout_p=dropout_p,
+             deterministic=deterministic,
+             window_size=local_attention,
+             causal=causal,
+         )
+     return attn
+
+
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
+         # (total_nnz, 3, nheads, headdim)
+         qkv = qkv.contiguous()
+         total_nnz, _three, _nheads, headdim = qkv.shape
+         # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+         # we get the same tensor
+         # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
+         qk = qkv[:, :2].view(total_nnz, -1, headdim)
+         apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
+
+         ctx.save_for_backward(cos, sin, cu_seqlens)
+         ctx.max_seqlen = max_seqlen
+         return qkv
+
+     @staticmethod
+     def backward(ctx, do):
+         cos, sin, cu_seqlens = ctx.saved_tensors
+         do = do.contiguous()
+         total_nnz, _three, _nheads, headdim = do.shape
+         # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
+         # we get the same tensor
+         dqk = do[:, :2].view(total_nnz, -1, headdim)
+         apply_rotary(
+             dqk,
+             cos,
+             sin,
+             seqlen_offsets=0,
+             cu_seqlens=cu_seqlens,
+             max_seqlen=ctx.max_seqlen,
+             interleaved=False,
+             inplace=True,
+             conjugate=True,
+         )
+
+         # One gradient per forward input: qkv, cos, sin, cu_seqlens, max_seqlen
+         return do, None, None, None, None
+
+
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+ def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
+     return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
+
+
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+ class UnpaddedRotaryEmbedding(RotaryEmbedding):
+     def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
+         super().__init__(dim=dim, base=base, device=None, interleaved=False)
+         self.max_seqlen = max_seqlen
+
+     def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         if max_seqlen is not None:
+             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+
+         qkv = apply_rotary_unpadded(
+             qkv,
+             self._cos_cached,
+             self._sin_cached,
+             cu_seqlens=cu_seqlens,
+             max_seqlen=max_seqlen,
+         )
+
+         return qkv
+
+
+ class RotaryPositionalEmbeddings(nn.Module):
+     def __init__(self, config, theta: int):
+         super().__init__()
+
+         head_size = config.query_key_head_size
+         assert head_size % 2 == 0
+         max_seq_len = config.max_sequence_length
+
+         inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
+         pos = torch.arange(max_seq_len, dtype=torch.float32)
+         embedding = torch.einsum('n, d -> nd', pos, inv_freq)
+         embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
+         self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
+         self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
+
+     def forward(self, x: torch.Tensor):
+         hidden_layer = x.float()
+
+         seq_len = x.shape[2]
+
+         cos_matrix = self.cos_matrix[:, None, :seq_len, :]
+         sin_matrix = self.sin_matrix[:, None, :seq_len, :]
+
+         x_rotate_half = torch.cat(
+             [
+                 -hidden_layer[:, :, :, x.size(-1) // 2:],
+                 hidden_layer[:, :, :, :x.size(-1) // 2]
+             ],
+             dim=-1
+         )
+
+         out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
+         return out.type_as(x)
+
+
+ class MaskedSoftmax(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
+         ctx.dim = dim
+         x.masked_fill_(mask, float('-inf'))
+         x = torch.softmax(x, ctx.dim)
+         x.masked_fill_(mask, 0.0)
+         ctx.save_for_backward(x)
+         return x
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
+         output: torch.Tensor
+
+         output, = ctx.saved_tensors
+         inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
+         return inputGrad, None, None
+
+
335
+ class SelfAttention(nn.Module):
336
+ def __init__(self, config: GptBertConfig, layer_idx: int):
337
+ super().__init__()
338
+
339
+ self.config = config
340
+ self.layer_idx = layer_idx
341
+
342
+ self.d_qk = config.query_key_head_size
343
+ self.d_v = config.value_head_size
344
+ self.num_attention_heads = config.num_attention_heads
345
+ self.num_kv_heads = config.num_attention_heads
346
+ self.hidden_size = config.hidden_size
347
+
348
+ self.q_out_dim = self.d_qk * self.num_attention_heads
349
+ self.k_out_dim = self.d_qk * self.num_kv_heads
350
+ self.v_out_dim = self.d_v * self.num_kv_heads
351
+
352
+ self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
353
+ self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
354
+ self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
355
+
356
+ self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
357
+ self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
358
+ self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
359
+ self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
360
+ self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
361
+ self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
362
+ self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
363
+
364
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
365
+ self.dropout = nn.Dropout(config.hidden_dropout)
366
+
367
+ theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
368
+
369
+ # Initialize rotary embeddings based on whether FlashAttention is available
370
+ if flash_attn_varlen_qkvpacked_func is not None:
371
+ self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
372
+ else:
373
+ self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
374
+
375
+ self.scale = 1.0 / math.sqrt(self.d_qk)
376
+ #self.lambdas = nn.Parameter(torch.tensor([0.5]))
377
+
378
+ self.sequence_length = config.max_sequence_length
379
+ self.is_causal = config.is_decoder
380
+ self.window_length = None
381
+
382
+ def set_window_length(self, window_length: int):
383
+ self.window_length = window_length
384
+
385
+ def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
386
+ """Create and cache window attention mask."""
387
+ if self.is_causal:
388
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
389
+ mask = mask.tril().triu(diagonal=-self.window_length)
390
+ else:
391
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
392
+ mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
393
+ return mask.view(1, 1, query_length, key_length)
394
+
395
+ def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
396
+ """Standard attention computation with masking."""
397
+ batch_size, _, query_length, _ = query.size()
398
+ _, _, key_length, _ = key.size()
399
+
400
+ # Use cached window mask
401
+ with torch.no_grad():
402
+ window_mask = self._get_window_mask(query_length, key_length, query.device)
403
+ if padding_mask is not None:
404
+ attention_mask = padding_mask & window_mask
405
+ else:
406
+ attention_mask = window_mask
407
+
408
+ attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, Q_T, K_T]
409
+ attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
410
+
411
+ attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
412
+ attention_probabilities = self.attention_dropout(attention_probabilities)
413
+
414
+ output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
415
+ output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
416
+
417
+ return output
418
+
419
+ def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
420
+ # Get original shape info
421
+ if flash_attn_varlen_qkvpacked_func is not None:
422
+ # Unpadded case
423
+ indices, cu_seqlens, max_seqlen = padding_info
424
+ total_seqlen = hidden_layer.size(0)
425
+            batch_size = cu_seqlens.size(0) - 1
+        else:
+            # Padded case
+            batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
+
+        hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
+        qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
+
+        query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
+        value = self.v_proj(hidden_layer)
+
+        if flash_attn_varlen_qkvpacked_func is not None:
+            # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
+            query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
+            key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
+            value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
+
+            # Apply layer norm and scaling
+            query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
+            key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
+
+            # if v1 is None:
+            #     v1 = value
+            # value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
+
+            # Prepare qkv for FlashAttention
+            qkv = torch.stack([query, key, value], dim=1)  # (total_seqlen, 3, num_heads, head_dim)
+
+            # Determine window size for local attention
+            if self.window_length is not None and self.window_length > 0:
+                if self.is_causal:
+                    local_attention = (self.window_length - 1, 0)
+                else:
+                    local_attention = (self.window_length - 1, self.window_length - 1)
+            else:
+                local_attention = (-1, -1)
+
+            # Apply FlashAttention
+            output = flash_attention_forward(
+                qkv,
+                self.rope_embedding,
+                cu_seqlens,
+                max_seqlen,
+                self.is_causal,
+                local_attention,
+                self.config.attention_dropout if self.training else 0.0,
+                self.config.deterministic_flash_attn
+            )
+
+            # Reshape output back
+            output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
+
+        else:
+            # Standard attention path
+            query_length = query.size(1)
+            key_length = key.size(1)
+
+            query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
+            key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
+            value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
+
+            query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
+            key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
+
+            # if v1 is None:
+            #     v1 = value
+            # else:
+            #     value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
+
+            # Apply rotary embeddings
+            query = self.rope_embedding(query)
+            key = self.rope_embedding(key)
+
+            output = self.attention_operation(query, key, value, padding_info)
+            output = output.transpose(1, 2).flatten(2, 3)  # shape: [B, T, H*D]
+
+        output = self.inter_norm(output.float()).type_as(output)
+        output = self.out_proj(output)
+        output = self.dropout(output)
+
+        return output, v1
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: GptBertConfig):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
+        self.activation = GeGLU()
+        self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+
+    def forward(self, x: torch.Tensor):
+        x = self.pre_norm(x.float()).type_as(x)
+        x = self.up_proj(x)
+        x = self.activation(x)
+        x = self.inter_norm(x.float()).type_as(x)
+        x = self.down_proj(x)
+        x = self.dropout(x)
+        return x
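The feed-forward block up-projects to two intermediate vectors and gates one with a GELU of the other (GeGLU). A minimal plain-Python sketch of that gating, assuming the first half of the concatenated projection acts as the gate (helper names here are illustrative, not the module's API):

```python
import math

def gelu(x: float) -> float:
    # Exact GELU via the Gauss error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(x: list[float]) -> list[float]:
    # Split the up-projected vector in half: first half gates, second half carries
    half = len(x) // 2
    gate, value = x[:half], x[half:]
    return [gelu(g) * v for g, v in zip(gate, value)]
```

The actual module applies the same split elementwise over tensors, with layer norms before and after the activation.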
+
+
+class Layer(nn.Module):
+    def __init__(self, config: GptBertConfig, layer_idx: int) -> None:
+        super().__init__()
+
+        self.attention: SelfAttention
+        self.mlp: FeedForward
+
+        self.attention = SelfAttention(config, layer_idx)
+        self.mlp = FeedForward(config)
+        self.lambdas_v = nn.Parameter(torch.tensor([1.0, 0.0]))
+        self.lambdas_qk = nn.Parameter(torch.tensor([1.0, 0.0]))
+        self.lambdas_mlp = nn.Parameter(torch.tensor([1.0, 1.0, 0.0]))
+        self.lambdas_out = nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 0.0]))
+
+    def set_window_length(self, window_length: int) -> None:
+        self.attention.set_window_length(window_length)
+
+    def normalize_lambda(self, lambdas: torch.Tensor) -> torch.Tensor:
+        lambdas = lambdas / (lambdas.abs().mean() + 1e-6)
+        return lambdas
+
+    def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
+        output: torch.Tensor
+
+        lambdas_v = self.normalize_lambda(self.lambdas_v)
+        lambdas_qk = self.normalize_lambda(self.lambdas_qk)
+        lambdas_mlp = self.normalize_lambda(self.lambdas_mlp)
+        lambdas_out = self.normalize_lambda(self.lambdas_out)
+
+        v_layer = (lambdas_v[0] * hidden_layer) + (lambdas_v[1] * embeddings)
+        qk_layer = (lambdas_qk[0] * hidden_layer) + (lambdas_qk[1] * embeddings)
+        attention_output, v1 = self.attention(v_layer, qk_layer, v1, padding_info)
+
+        mlp_layer = (lambdas_mlp[0] * attention_output) + (lambdas_mlp[1] * hidden_layer) + (lambdas_mlp[2] * embeddings)
+        mlp_layer = self.mlp(mlp_layer)
+
+        output = (lambdas_out[0] * mlp_layer) + (lambdas_out[1] * attention_output) + (lambdas_out[2] * hidden_layer) + (lambdas_out[3] * embeddings)
+
+        return output, v1
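Each layer mixes the current residual stream with the original token embeddings using learned λ coefficients, first rescaled so their mean absolute value is roughly 1. A plain-Python sketch of the same arithmetic (function names are illustrative):

```python
def normalize_lambdas(lambdas: list[float], eps: float = 1e-6) -> list[float]:
    # Rescale so the mean absolute value of the coefficients is ~1
    mean_abs = sum(abs(l) for l in lambdas) / len(lambdas)
    return [l / (mean_abs + eps) for l in lambdas]

def mix(lambdas: list[float], streams: list[list[float]]) -> list[float]:
    # Weighted sum of residual streams (e.g. hidden state and embeddings)
    lambdas = normalize_lambdas(lambdas)
    return [sum(l * s[i] for l, s in zip(lambdas, streams)) for i in range(len(streams[0]))]
```

With the default initialization `[1.0, 0.0]`, the mix starts out as the plain hidden state, and training can shift weight back toward the static embeddings.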
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: GptBertConfig):
+        super().__init__()
+        self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
+        self.local_global_ratio = config.local_global_ratio
+
+    def set_window_length(self, config: GptBertConfig):
+        for i, layer in enumerate(self.layers):
+            if (i + 1) % self.local_global_ratio == 0:
+                layer.set_window_length(config.global_window_length)
+            else:
+                layer.set_window_length(config.local_window_length)
+
+    def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
+        hidden_layers = [hidden_layer] if output_hidden_states else None
+        v1 = None
+        embeddings = hidden_layer
+
+        for layer in self.layers:
+            if checkpoint_activations:
+                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
+            else:
+                hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
+
+            if output_hidden_states:
+                hidden_layers.append(hidden_layer)
+
+        return hidden_layer, hidden_layers
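The encoder interleaves local and global attention: every `local_global_ratio`-th layer gets the global window, the rest the local one. The schedule can be sketched in plain Python (function name is illustrative):

```python
def window_schedule(num_layers: int, local_global_ratio: int,
                    local_window: int, global_window: int) -> list[int]:
    # Every `local_global_ratio`-th layer attends globally; the rest stay local
    return [
        global_window if (i + 1) % local_global_ratio == 0 else local_window
        for i in range(num_layers)
    ]
```

For example, with 24 layers and a ratio of 3, layers 3, 6, 9, ... (1-indexed) attend globally.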
+
+
+#
+# HuggingFace wrappers
+#
+
+class GptBertPreTrainedModel(PreTrainedModel):
+    config_class = GptBertConfig
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = False
+
+    def _init_weights(self, module):
+        std = math.sqrt(2.0 / (5.0 * self.config.hidden_size))
+
+        if isinstance(module, (nn.Linear, CastedLinearIn)):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+        elif isinstance(module, nn.LayerNorm):
+            # Layer norms are created with elementwise_affine=False, so guard for missing parameters
+            if module.bias is not None:
+                module.bias.data.zero_()
+            if module.weight is not None:
+                module.weight.data.fill_(1.0)
+
+
+class GptBertModel(GptBertPreTrainedModel):
+    def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
+        super().__init__(config, **kwargs)
+        self.config = config
+        self.hidden_size = config.hidden_size
+
+        self.embedding = Embedding(config)
+        self.encoder = Encoder(config)
+        self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
+        self.set_window_length(config)
+        self.gradient_checkpointing = False
+        self.post_init()
+
+    def set_window_length(self, config) -> None:
+        self.encoder.set_window_length(config)
+
+    def get_input_embeddings(self):
+        return self.embedding.word_embedding
+
+    def set_input_embeddings(self, value):
+        self.embedding.word_embedding = value
+
+    def get_contextualized_embeddings(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None
+    ):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            raise ValueError("You have to specify input_ids")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
+        else:
+            attention_mask = attention_mask.bool()
+
+        if flash_attn_varlen_qkvpacked_func is not None:
+            if len(attention_mask.size()) != 2:
+                raise ValueError("Only a 2-dimensional `attention_mask` is currently supported with FlashAttention.")
+            with torch.no_grad():
+                input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
+            padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
+        else:
+            if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            padding_info = attention_mask
+
+        static_embeddings = self.embedding(input_ids)
+
+        original_dtype = static_embeddings.dtype
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
+            static_embeddings = static_embeddings.bfloat16()
+
+        last_layer, contextualized_embeddings = self.encoder(
+            static_embeddings,
+            padding_info,
+            output_hidden_states=output_hidden_states,
+            checkpoint_activations=self.gradient_checkpointing and self.training
+        )
+
+        last_layer = last_layer.to(original_dtype)
+        if output_hidden_states:
+            contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
+
+        # Pad the output back out if FlashAttention unpadding was used
+        if flash_attn_varlen_qkvpacked_func is not None:
+            last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
+            if output_hidden_states:
+                contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
+            else:
+                contextualized_embeddings = None
+
+        return last_layer, contextualized_embeddings
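For FlashAttention's varlen kernel, padded batches are flattened into one packed sequence described by cumulative sequence lengths (`cu_seqlens`) and the longest sequence in the batch. A plain-Python sketch of that bookkeeping (not the actual `_unpad_input` helper, whose names and outputs differ slightly):

```python
def cumulative_seqlens(attention_mask: list[list[int]]) -> tuple[list[int], int]:
    # attention_mask: one row per sequence, 1 = real token, 0 = padding
    seqlens = [sum(row) for row in attention_mask]
    cu_seqlens = [0]
    for n in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return cu_seqlens, max(seqlens)
```

Sequence i then occupies positions `cu_seqlens[i]:cu_seqlens[i+1]` of the packed tensor, and `_pad_output` inverts the operation after the encoder.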
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+
+        if not return_dict:
+            return (
+                sequence_output,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+
+        return BaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
+
+
+class GptBertForMaskedLM(GptBertModel):
+    _tied_weights_keys = ["classifier.emb2vocab.weight"]
+
+    def __init__(self, config: GptBertConfig, **kwargs):
+        super().__init__(config, add_mlm_layer=True, **kwargs)
+
+    def get_output_embeddings(self):
+        return self.classifier.emb2vocab.weight
+
+    def set_output_embeddings(self, new_embeddings):
+        self.classifier.emb2vocab.weight = new_embeddings
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        subword_prediction = self.classifier(sequence_output)
+        subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
+
+        masked_lm_loss = None
+        if labels is not None:
+            labels_flatten = labels[:, 1:].flatten()
+            subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
+            masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
+
+        bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
+        bos_logits[:, :, self.config.bos_token_id] = 1.0
+        subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
+
+        if not return_dict:
+            output = (
+                subword_prediction,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=subword_prediction,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
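Note the GPT-style shift in the loss: the prediction at position t is scored against the label at position t + 1 (`subword_prediction[:, :-1]` versus `labels[:, 1:]`), which is also why a synthetic BOS logit column is prepended before the logits are returned. The alignment can be sketched in plain Python (function name is illustrative):

```python
def shifted_pairs(predictions: list[str], labels: list[str]) -> list[tuple[str, str]]:
    # The prediction at position t is scored against the label at position t + 1
    return list(zip(predictions[:-1], labels[1:]))
```

The last prediction and the first label are unpaired, matching the two slices in the loss computation above.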
+
+
+class GptBertForCausalLM(GptBertModel):
+    _tied_weights_keys = ["classifier.emb2vocab.weight"]
+
+    def __init__(self, config: GptBertConfig, **kwargs):
+        config.is_decoder = True
+        super().__init__(config, add_mlm_layer=True, **kwargs)
+
+    def get_output_embeddings(self):
+        return self.classifier.emb2vocab.weight
+
+    def set_output_embeddings(self, new_embeddings):
+        self.classifier.emb2vocab.weight = new_embeddings
+
+    def get_input_embeddings(self):
+        return self.embedding.word_embedding
+
+    def set_input_embeddings(self, value):
+        self.embedding.word_embedding = value
+
+    def set_decoder(self, decoder):
+        self.encoder = decoder
+
+    def get_decoder(self):
+        return self.encoder
+
+    def can_generate(self):
+        return True
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None
+    ) -> Union[Tuple, CausalLMOutput]:
+        assert inputs_embeds is None, "inputs_embeds is not supported for now"
+        assert past_key_values is None, "past_key_values is not supported for now"
+        assert not use_cache, "use_cache is not supported for now"
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        subword_prediction = self.classifier(sequence_output)
+        subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
+
+        causal_lm_loss = None
+        if labels is not None:
+            labels_flatten = labels[:, 1:].flatten()
+            subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
+            causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
+
+        if not return_dict:
+            output = (
+                subword_prediction,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+            return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
+
+        return CausalLMOutput(
+            loss=causal_lm_loss,
+            logits=subword_prediction,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        use_cache: bool = True,
+        num_logits_to_keep: Optional[int] = None,
+        **kwargs,
+    ):
+        # If we have a cache, slice `input_ids` through `cache_position` to keep only the unprocessed tokens.
+        # Exception 1: when passing inputs_embeds, input_ids may be missing entries.
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here.
+        if past_key_values is not None:
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0]:]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+
+        if attention_mask is not None and position_ids is None:
+            # Create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+            # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s
+            # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying
+            # strides during decoding. Simply using `.contiguous()` is not sufficient: in the
+            # batch size = 1 case, `position_ids` is already contiguous, but its varying stride
+            # still retriggers a capture.
+            position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+        # If `inputs_embeds` are passed, only use them in the first generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+
+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+
+class GptBertForSequenceClassification(GptBertModel):
+    _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+    _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+
+    def __init__(self, config: GptBertConfig, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.classifier = Classifier(config, self.num_labels)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        logits = self.classifier(sequence_output[:, 0, :])
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = nn.MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = nn.BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (
+                logits,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
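When `config.problem_type` is unset, the head infers it from the label count and dtype, following the usual Hugging Face convention. The decision rule can be sketched in plain Python (function name is illustrative):

```python
def infer_problem_type(num_labels: int, labels_dtype: str) -> str:
    # Mirrors the fallback used when config.problem_type is unset
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and labels_dtype in ("long", "int"):
        return "single_label_classification"
    return "multi_label_classification"
```

So one output with any labels means regression (MSE), several outputs with integer labels mean single-label classification (cross-entropy), and several outputs with float labels mean multi-label classification (BCE with logits).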
+
+
+class GptBertForTokenClassification(GptBertModel):
+    _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+    _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+
+    def __init__(self, config: GptBertConfig, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.classifier = Classifier(config, self.num_labels)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (
+                logits,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
+
+
+class GptBertForQuestionAnswering(GptBertModel):
+    _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+    _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+
+    def __init__(self, config: GptBertConfig, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.classifier = Classifier(config, self.num_labels)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        logits = self.classifier(sequence_output)
+
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+
+            # Sometimes the start/end positions lie outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (
+                start_logits,
+                end_logits,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
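Answer positions that fall outside the model input are clamped to `seq_length`, which is then passed to the loss as `ignore_index` so those examples contribute nothing. The clamping step in plain Python (function name is illustrative):

```python
def clamp_positions(positions: list[int], seq_length: int) -> list[int]:
    # Out-of-range answer positions are clamped to seq_length,
    # which the cross-entropy loss then ignores via ignore_index
    return [min(max(p, 0), seq_length) for p in positions]
```

Clamping to `seq_length` (one past the last valid index) rather than `seq_length - 1` is deliberate: a truncated answer should be ignored, not mapped to the final token.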
+
+
+class GptBertForMultipleChoice(GptBertModel):
+    _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+    _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
+
+    def __init__(self, config: GptBertConfig, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
+
+        # One score per choice; the scores are reshaped to (batch_size, num_choices) in forward
+        self.num_labels = getattr(config, "num_labels", 2)
+        self.classifier = Classifier(config, 1)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1]
+
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
+        logits = self.classifier(sequence_output[:, 0, :])
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (
+                reshaped_logits,
+                *([contextualized_embeddings] if output_hidden_states else [])
+            )
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=contextualized_embeddings if output_hidden_states else None
+        )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25dd6b44c778dc78c07a249d05ff982dd6c1d395c625fb09ea42f1e855610571
+size 597586413
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+{"bos_token": "[CLS]", "eos_token": "[SEP]", "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+{
+    "tokenizer_class": "PreTrainedTokenizerFast",
+    "bos_token": "[CLS]",
+    "eos_token": "[SEP]",
+    "unk_token": "[UNK]",
+    "sep_token": "[SEP]",
+    "pad_token": "[PAD]",
+    "cls_token": "[CLS]",
+    "mask_token": "[MASK]"
+}