pszemraj commited on
Commit
a07e32a
·
verified ·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files
Files changed (10) hide show
  1. .gitattributes +35 -0
  2. README.md +203 -0
  3. config.json +24 -0
  4. model.py +505 -0
  5. model.safetensors +3 -0
  6. pytorch_model.bin +3 -0
  7. rotary.py +62 -0
  8. special_tokens_map.json +51 -0
  9. tokenizer.json +0 -0
  10. tokenizer_config.json +60 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: mit
4
+ datasets:
5
+ - EleutherAI/SmolLM2-1.7B-stage-4-100B
6
+ language:
7
+ - en
8
+ ---
9
+
10
+ # NeoBERT Model
11
+
12
+ This is a NeoBERT model trained with [pszemraj/NeoBERT](https://github.com/pszemraj/NeoBERT) and exported to `transformers` format.
13
+
14
+ ## Model Details
15
+ - **Architecture**: NeoBERT
16
+ - **Hidden Size**: 768
17
+ - **Layers**: 12
18
+ - **Attention Heads**: 12
19
+ - **Vocab Size**: 31999
20
+ - **Max Length**: 4096
21
+ - **Dtype**: float32
22
+
23
+ ## Usage
24
+
25
+ ### For Masked Language Modeling (Fill-Mask)
26
+
27
+ ```python
28
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
29
+ import torch
30
+
31
+ repo_id = "BEE-spoke-data/neobert-100k-test"
32
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
33
+ model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
34
+
35
+ # Example: Fill in masked tokens
36
+ text = "NeoBERT is the most [MASK] model of its kind!"
37
+
38
+ # Tokenize (handling Metaspace tokenizer's space tokens)
39
+ inputs = tokenizer(text, return_tensors="pt")
40
+ input_ids = inputs["input_ids"][0].tolist()
41
+
42
+ # Remove extra space tokens before [MASK] if present (Metaspace tokenizer quirk)
43
+ cleaned_ids = []
44
+ for i, token_id in enumerate(input_ids):
45
+ if token_id == 454 and i < len(input_ids) - 1 and input_ids[i + 1] == tokenizer.mask_token_id:
46
+ continue
47
+ cleaned_ids.append(token_id)
48
+
49
+ if len(cleaned_ids) != len(input_ids):
50
+ inputs["input_ids"] = torch.tensor([cleaned_ids])
51
+ inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
52
+
53
+ # Get predictions
54
+ with torch.no_grad():
55
+ outputs = model(**inputs)
56
+ mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1][0]
57
+ predictions = outputs.logits[0, mask_pos].topk(5)
58
+
59
+ # Display top predictions
60
+ for idx, score in zip(predictions.indices, predictions.values):
61
+ token = tokenizer.decode([idx])
62
+ print(f"{token}: {score:.2f}")
63
+ ```
64
+
65
+ ### For Embeddings / Feature Extraction
66
+
67
+ ```python
68
+ from transformers import AutoModel, AutoTokenizer
69
+
70
+ repo_id = "BEE-spoke-data/neobert-100k-test"
71
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
72
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
73
+
74
+ # Example: Generate embeddings
75
+ text = "NeoBERT is an efficient transformer model!"
76
+ inputs = tokenizer(text, return_tensors="pt")
77
+ outputs = model(**inputs)
78
+
79
+ # Get CLS token embedding
80
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
81
+ print(f"Embedding shape: {cls_embedding.shape}")
82
+ ```
83
+
84
+ ## Training Configuration
85
+
86
+ <details>
87
+ <summary><strong>Full Config</strong> (click to expand)</summary>
88
+
89
+ Full training config:
90
+
91
+ ```yaml
92
+ model:
93
+ hidden_size: 768
94
+ num_hidden_layers: 12
95
+ num_attention_heads: 12
96
+ intermediate_size: 3072
97
+ max_position_embeddings: 4096
98
+ vocab_size: 31999
99
+ rope: true
100
+ rms_norm: true
101
+ hidden_act: swiglu
102
+ dropout_prob: 0.05
103
+ norm_eps: 1.0e-05
104
+ embedding_init_range: 0.02
105
+ decoder_init_range: 0.02
106
+ classifier_init_range: 0.02
107
+ flash_attention: true
108
+ ngpt: false
109
+ base_scale: 0.03227486121839514
110
+ pad_token_id: 0
111
+ dataset:
112
+ name: EleutherAI/SmolLM2-1.7B-stage-4-100B
113
+ path: ''
114
+ num_workers: 4
115
+ streaming: true
116
+ cache_dir: null
117
+ max_seq_length: 1024
118
+ validation_split: null
119
+ train_split: train
120
+ eval_split: train[:1%]
121
+ num_proc: 8
122
+ shuffle_buffer_size: 10000
123
+ pre_tokenize: false
124
+ pre_tokenize_output: null
125
+ load_all_from_disk: false
126
+ force_redownload: false
127
+ pretraining_prob: 0.3
128
+ min_length: 512
129
+ tokenizer:
130
+ name: BEE-spoke-data/wordpiece-tokenizer-32k-en_code-msp
131
+ path: null
132
+ max_length: 1024
133
+ padding: max_length
134
+ truncation: true
135
+ vocab_size: 31999
136
+ optimizer:
137
+ name: adamw
138
+ lr: 0.0001
139
+ weight_decay: 0.01
140
+ betas:
141
+ - 0.9
142
+ - 0.98
143
+ eps: 1.0e-08
144
+ scheduler:
145
+ name: cosine
146
+ warmup_steps: 5000
147
+ total_steps: null
148
+ num_cycles: 0.5
149
+ decay_steps: 50000
150
+ warmup_percent: null
151
+ decay_percent: null
152
+ trainer:
153
+ per_device_train_batch_size: 16
154
+ per_device_eval_batch_size: 16
155
+ gradient_accumulation_steps: 4
156
+ max_steps: 100000
157
+ save_steps: 10000
158
+ eval_steps: 5000
159
+ logging_steps: 25
160
+ output_dir: ./outputs/neobert_100m_100k
161
+ overwrite_output_dir: true
162
+ bf16: true
163
+ gradient_checkpointing: false
164
+ gradient_clipping: null
165
+ mixed_precision: 'no'
166
+ seed: 42
167
+ resume_from_checkpoint: false
168
+ disable_tqdm: false
169
+ dataloader_num_workers: 0
170
+ use_cpu: false
171
+ report_to:
172
+ - wandb
173
+ tf32: true
174
+ max_ckpt: 3
175
+ train_batch_size: 16
176
+ eval_batch_size: 32
177
+ datacollator:
178
+ mlm_probability: 0.2
179
+ pad_to_multiple_of: 8
180
+ wandb:
181
+ project: neobert-pretraining
182
+ entity: null
183
+ name: neobert-100m-100k
184
+ tags: []
185
+ mode: online
186
+ log_interval: 100
187
+ resume: never
188
+ dir: logs/wandb
189
+ task: pretraining
190
+ accelerate_config_file: null
191
+ mixed_precision: bf16
192
+ mteb_task_type: all
193
+ mteb_batch_size: 32
194
+ mteb_pooling: mean
195
+ mteb_overwrite_results: false
196
+ pretrained_checkpoint: latest
197
+ use_deepspeed: true
198
+ seed: 69
199
+ debug: false
200
+
201
+ ```
202
+
203
+ </details>
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "NeoBERTLMHead"
4
+ ],
5
+ "model_type": "neobert",
6
+ "auto_map": {
7
+ "AutoConfig": "model.NeoBERTConfig",
8
+ "AutoModel": "model.NeoBERT",
9
+ "AutoModelForMaskedLM": "model.NeoBERTLMHead",
10
+ "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
11
+ },
12
+ "hidden_size": 768,
13
+ "num_hidden_layers": 12,
14
+ "num_attention_heads": 12,
15
+ "intermediate_size": 3072,
16
+ "vocab_size": 31999,
17
+ "max_length": 4096,
18
+ "embedding_init_range": 0.02,
19
+ "decoder_init_range": 0.02,
20
+ "norm_eps": 1e-05,
21
+ "pad_token_id": 0,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.55.0"
24
+ }
model.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
+ from torch.nn.functional import scaled_dot_product_attention
10
+
11
+ try:
12
+ from xformers.ops import SwiGLU
13
+
14
+ XFORMERS_AVAILABLE = True
15
+ except ImportError:
16
+ XFORMERS_AVAILABLE = False
17
+
18
+ try:
19
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
20
+
21
+ FLASH_ATTN_AVAILABLE = True
22
+ except ImportError:
23
+ FLASH_ATTN_AVAILABLE = False
24
+
25
+ from transformers import (
26
+ DataCollatorForLanguageModeling,
27
+ PretrainedConfig,
28
+ PreTrainedModel,
29
+ )
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ MaskedLMOutput,
33
+ SequenceClassifierOutput,
34
+ )
35
+
36
+ from .rotary import apply_rotary_emb, precompute_freqs_cis
37
+
38
+
39
class DataCollatorWithPacking(DataCollatorForLanguageModeling):
    """MLM data collator that can pack a whole batch into one long row.

    With ``pack_sequences=True`` every sequence in the batch is concatenated
    into a single packed example, and ``cu_seqlens`` / ``max_seqlen`` metadata
    is produced so the model can run flash-attention over the variable-length
    segments.  Otherwise it behaves like ``DataCollatorForLanguageModeling``,
    additionally casting the attention mask to bool.
    """

    def __init__(self, pack_sequences: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pack_sequences = pack_sequences

    def __call__(self, batch):
        if self.pack_sequences:
            # Add position_ids if not present, so that positions restart at 0
            # for every original sequence inside the packed row.
            if "position_ids" not in batch[0]:
                for item in batch:
                    item["position_ids"] = list(range(len(item["input_ids"])))

            # Pack the sequences into a single list
            input_ids_list = [item["input_ids"] for item in batch]
            position_ids_list = [item["position_ids"] for item in batch]
            # Leading 0 makes the cumulative sum below yield segment
            # boundaries [0, len_0, len_0 + len_1, ...] as flash-attn expects.
            seqlens = np.array([0] + [len(ids) for ids in input_ids_list])

            packed_batch = {
                "position_ids": np.concatenate(position_ids_list, axis=0),
                "input_ids": np.concatenate(input_ids_list, axis=0),
                "cu_seqlens": np.cumsum(seqlens),
                "max_seqlen": max(seqlens),
            }

            # Let the parent collator apply MLM masking to the packed example.
            batch = super().__call__([packed_batch])
            # flash_attn_varlen_func requires int32 cumulative lengths.
            batch["cu_seqlens"] = batch["cu_seqlens"].to(torch.int32).squeeze()
        else:
            batch = super().__call__(batch)
            batch["attention_mask"] = batch["attention_mask"].to(torch.bool)

        return batch
70
+
71
+
72
class NeoBERTConfig(PretrainedConfig):
    """Configuration class for NeoBERT models.

    Every parameter carries a default value so the config can be
    instantiated with no arguments, as `transformers` requires.
    """

    model_type = "neobert"

    # All config parameters must have a default value.
    def __init__(
        self,
        hidden_size: int = 768,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-06,
        vocab_size: int = 30522,
        pad_token_id: int = 0,
        max_length: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if hidden_size % num_attention_heads != 0:
            raise ValueError("Hidden size must be divisible by the number of heads.")

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Per-head dimension, consumed by the attention blocks.
        self.dim_head = hidden_size // num_attention_heads
        self.intermediate_size = intermediate_size
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        # Keep any extra keyword arguments around for downstream consumers.
        self.kwargs = kwargs
+ self.kwargs = kwargs
106
+
107
+
108
# Adapted from transformers.models.llama.modeling_llama.LlamaMLP
class NeobertMLP(nn.Module):
    """SwiGLU feed-forward block: ``w3(silu(gate) * up)``."""

    def __init__(self, hidden_size, intermediate_size, bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Fused gate + up projection; split into two halves in forward().
        self.w12 = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=bias)
        self.w3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate, up = self.w12(x).chunk(2, dim=-1)
        return self.w3(self.act_fn(gate) * up)
122
+
123
+
124
class EncoderBlock(nn.Module):
    """Transformer encoder block.

    Pre-norm residual layout: ``x + attn(norm(x))`` then ``x + ffn(norm(x))``.
    Attention dispatches to flash-attention (packed sequences), eager matmul
    (when attention weights are requested), or SDPA otherwise.
    """

    def __init__(self, config: NeoBERTConfig):
        super().__init__()

        self.config = config

        # Attention
        self.qkv = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )
        self.wo = nn.Linear(
            in_features=config.hidden_size, out_features=config.hidden_size, bias=False
        )

        # Feedforward network
        # SwiGLU uses two input projections, so the width is scaled by 2/3
        # and rounded up to a multiple of 8 to keep hardware-friendly sizes.
        multiple_of = 8
        intermediate_size = int(2 * config.intermediate_size / 3)
        intermediate_size = multiple_of * (
            (intermediate_size + multiple_of - 1) // multiple_of
        )
        if XFORMERS_AVAILABLE:
            self.ffn = SwiGLU(
                config.hidden_size, intermediate_size, config.hidden_size, bias=False
            )
        else:
            self.ffn = NeobertMLP(config.hidden_size, intermediate_size, bias=False)

        # Layer norms
        self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        output_attentions: bool,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
    ):
        """Apply one pre-norm attention + feed-forward block.

        Returns the updated hidden states and the attention weights
        (``None`` unless ``output_attentions`` is set).
        """
        # Attention
        attn_output, attn_weights = self._att_block(
            self.attention_norm(x),
            attention_mask,
            freqs_cis,
            output_attentions,
            max_seqlen,
            cu_seqlens,
        )

        # Residual
        x = x + attn_output

        # Feed-forward
        x = x + self.ffn(self.ffn_norm(x))

        return x, attn_weights

    def _att_block(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        output_attentions: bool,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
    ):
        """Multi-head attention with rotary embeddings (three backends)."""
        batch_size, seq_len, _ = x.shape

        # Single projection to (B, L, H, 3*dim_head), then split into q/k/v.
        xq, xk, xv = (
            self.qkv(x)
            .view(
                batch_size,
                seq_len,
                self.config.num_attention_heads,
                self.config.dim_head * 3,
            )
            .chunk(3, axis=-1)
        )

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # Attn block
        attn_weights = None

        # Flash attention if the tensors are packed
        if cu_seqlens is not None:
            attn = flash_attn_varlen_func(
                q=xq.squeeze(0),
                k=xk.squeeze(0),
                v=xv.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )
        # Eager attention if attention weights are needed in the output
        elif output_attentions:
            attn_weights = (
                xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
            )
            if attention_mask is not None:
                # NOTE(review): the mask is multiplied into the raw scores
                # rather than added as -inf before softmax, so masked
                # positions can still receive non-zero probability — confirm
                # this eager path is meant to match the SDPA path's masking.
                attn_weights = attn_weights * attention_mask
            attn_weights = attn_weights.softmax(-1)
            attn = attn_weights @ xv.permute(0, 2, 1, 3)
            attn = attn.transpose(1, 2)
        # Fall back to SDPA otherwise
        else:
            attn = scaled_dot_product_attention(
                query=xq.transpose(1, 2),
                key=xk.transpose(1, 2),
                value=xv.transpose(1, 2),
                attn_mask=attention_mask.bool() if attention_mask is not None else None,
                dropout_p=0,
            ).transpose(1, 2)

        # Merge heads back to (B, L, hidden) and project out.
        return (
            self.wo(
                attn.reshape(
                    batch_size,
                    seq_len,
                    self.config.num_attention_heads * self.config.dim_head,
                )
            ),
            attn_weights,
        )
256
+
257
+
258
class NeoBERTPreTrainedModel(PreTrainedModel):
    """Shared base class: weight initialization + transformers integration."""

    config_class = NeoBERTConfig
    base_model_prefix = "model"
    _supports_cache_class = True

    def _init_weights(self, module):
        """Uniformly initialize Linear/Embedding weights in the configured range."""
        if isinstance(module, nn.Linear):
            bound = self.config.decoder_init_range
        elif isinstance(module, nn.Embedding):
            bound = self.config.embedding_init_range
        else:
            return
        module.weight.data.uniform_(-bound, bound)
272
+
273
+
274
class NeoBERT(NeoBERTPreTrainedModel):
    """Bare NeoBERT encoder: token embedding, RoPE, pre-norm transformer stack."""

    config_class = NeoBERTConfig

    def __init__(self, config: NeoBERTConfig):
        super().__init__(config)

        self.config = config

        # Token embedding table (padding index is zeroed/frozen by nn.Embedding).
        self.encoder = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )

        # Ensures freqs_cis is moved to the same devices as the model. Non-persistent buffers are not saved in the state_dict.
        freqs_cis = precompute_freqs_cis(
            config.hidden_size // config.num_attention_heads, config.max_length
        )
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        self.transformer_encoder = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.transformer_encoder.append(EncoderBlock(config))

        # Final RMSNorm applied after the last encoder block.
        self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        """Encode ``input_ids`` and return the final hidden states.

        Args:
            input_ids: Token ids, shape (batch, seq_len); shape (1, total_len)
                when sequences are packed via ``cu_seqlens``.
            position_ids: Optional explicit positions used to index the RoPE
                buffer (needed for packed rows so positions restart per segment).
            max_seqlen: Longest segment length in a packed batch; required
                whenever ``cu_seqlens`` is given.
            cu_seqlens: Cumulative segment boundaries for flash-attn packing
                (CUDA only).
            attention_mask: Optional (batch, seq_len) mask, expanded below to
                (batch, heads, seq_len, seq_len) for the encoder blocks.
            output_hidden_states: Collect each layer's output.
            output_attentions: Collect each layer's attention weights
                (not supported together with packing).
        """
        # Initialize
        hidden_states, attentions = [], []

        # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
        if attention_mask is not None:
            attention_mask = (
                attention_mask.unsqueeze(1)
                .unsqueeze(1)
                .repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
            )

        # Checks to be done if inputs are packed sequences
        if cu_seqlens is not None:
            assert FLASH_ATTN_AVAILABLE, (
                "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
            )
            assert not output_attentions, (
                "Output attentions is not supported when sequences are packed."
            )
            assert max_seqlen is not None, (
                "Missing max_seqlen. It must be provided when cu_seqlens are not None."
            )
            assert input_ids.shape[0] == 1, (
                "Cumulative sequence lengths are provided but input_ids are not packed."
            )
            assert input_ids.is_cuda, (
                "Packing uses an implementation of flash-attention and is only supported on GPU."
            )

        # RoPE
        # Either gather per-token rotations via position_ids, or take the
        # leading seq_len rows of the precomputed buffer.
        freqs_cis = (
            self.freqs_cis[position_ids]
            if position_ids is not None
            else self.freqs_cis[: input_ids.shape[1]].unsqueeze(0)
        )

        # Embedding
        x = self.encoder(input_ids)

        # Transformer encoder
        for layer in self.transformer_encoder:
            x, attn = layer(
                x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens
            )
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        # Final normalization layer
        x = self.layer_norm(x)

        # Return the output of the last hidden layer
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=attentions if output_attentions else None,
        )
370
+
371
+
372
class NeoBERTLMHead(NeoBERTPreTrainedModel):
    """NeoBERT encoder topped with a linear decoder for masked-LM."""

    config_class = NeoBERTConfig

    def __init__(self, config: NeoBERTConfig):
        super().__init__(config)

        self.config = config

        # Backbone encoder.
        self.model = NeoBERT(config)
        # Vocabulary projection (keeps the default bias, unlike the encoder's
        # bias-free linears).
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        """Run the encoder and project its hidden states to vocabulary logits."""
        encoder_out = self.model.forward(
            input_ids,
            position_ids,
            max_seqlen,
            cu_seqlens,
            attention_mask,
            output_hidden_states,
            output_attentions,
        )

        return MaskedLMOutput(
            hidden_states=encoder_out.hidden_states if output_hidden_states else None,
            attentions=encoder_out.attentions if output_attentions else None,
            logits=self.decoder(encoder_out.last_hidden_state),
        )
412
+
413
+
414
class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
    """NeoBERT encoder with a first-token-pooled classification head.

    Head layout: dropout -> dense -> tanh -> dropout -> classifier.  The loss
    is selected from ``config.problem_type``, inferred from labels on the
    first call if unset (transformers convention).
    """

    config_class = NeoBERTConfig

    def __init__(self, config: NeoBERTConfig):
        super().__init__(config)

        self.config = config

        # Head hyper-parameters fall back to defaults when absent from config.
        self.num_labels = getattr(config, "num_labels", 2)
        self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
        self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)

        self.model = NeoBERT(config)

        self.dense = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.dropout = nn.Dropout(self.classifier_dropout)
        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)

        self.post_init()

    def _init_weights(self, module):
        # Normal init for the head (the base class uses uniform init instead).
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Classify from the first token's final hidden state.

        Returns a ``SequenceClassifierOutput``, or a plain tuple of
        ``(loss?, logits)`` when ``return_dict`` is falsy.
        """
        output = self.model.forward(
            input_ids,
            position_ids,
            max_seqlen,
            cu_seqlens,
            attention_mask,
            output_hidden_states,
            output_attentions,
        )
        hidden_states = output.last_hidden_state

        # Pool: take the first token's representation.
        x = hidden_states[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)

        logits = self.classifier(x)

        loss = None
        if labels is not None:
            # Infer and cache the problem type once; note this mutates the
            # shared config (standard transformers behavior).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            result = (logits,)
            return ((loss,) + result) if loss is not None else result

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=output.hidden_states if output_hidden_states else None,
            attentions=output.attentions if output_attentions else None,
        )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33b6bf91a447ce603db6ae080de1a67dde749fd867f471029a22a03cbdfa9eab
3
+ size 536553948
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d71fc827b1bce4bf6201186eaea64c84a903f567a27b106f43ae2adc59fc250c
3
+ size 536570475
rotary.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+
7
+
8
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials, shape (end, dim // 2).
    """
    # Inverse frequencies for the even channels: theta^(-2i/dim).
    half = dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[:half].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    # Angle per (position, channel) pair.
    angles = torch.outer(positions, inv_freq).float()
    # Unit-magnitude complex exponentials e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)
29
+
30
+
31
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert a singleton head axis so ``freqs_cis`` broadcasts against ``x``.

    Expects ``freqs_cis`` to match ``x``'s second and last dimensions;
    returns a view of shape (batch, seq_len, 1, head_dim).
    """
    expected = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape[1:] == expected
    return freqs_cis.contiguous().unsqueeze(2)
34
+
35
+
36
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    The query 'xq' and key 'xk' tensors are viewed as complex pairs over their
    last dimension, multiplied by the (broadcast) complex exponentials in
    'freqs_cis', and converted back to real tensors of the original dtype.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # View the last dimension as complex pairs: (..., d) -> (..., d/2) complex.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast the rotations over the head axis (inlined helper: shape check
    # plus a singleton head dimension).
    assert freqs_cis.shape[1:] == (q_complex.shape[1], q_complex.shape[-1])
    rot = freqs_cis.contiguous().unsqueeze(2)
    q_out = torch.view_as_real(q_complex * rot).flatten(3)
    k_out = torch.view_as_real(k_complex * rot).flatten(3)
    return q_out.type_as(xq), k_out.type_as(xk)
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "[CLS]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "[SEP]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "[MASK]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "[SEP]",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "[UNK]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[UNK]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[CLS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[SEP]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[PAD]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "[CLS]",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "[CLS]",
47
+ "eos_token": "[SEP]",
48
+ "extra_special_tokens": {},
49
+ "mask_token": "[MASK]",
50
+ "max_length": 1024,
51
+ "model_max_length": 4096,
52
+ "pad_token": "[PAD]",
53
+ "sep_token": "[SEP]",
54
+ "stride": 0,
55
+ "tokenizer_class": "PreTrainedTokenizerFast",
56
+ "truncation_side": "right",
57
+ "truncation_strategy": "longest_first",
58
+ "unk_token": "[UNK]",
59
+ "vocab_size": 31999
60
+ }