ashercn97 commited on
Commit
fc2f01d
·
verified ·
1 Parent(s): 3b8c66c

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +514 -0
model.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7
+ from torch.nn.functional import scaled_dot_product_attention
8
+
9
+ from typing import Optional, Tuple
10
+ import numpy as np
11
+
12
+
13
+ try:
14
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
15
+
16
+ FLASH_ATTN_AVAILABLE = True
17
+ print("USE FLASH ATTN")
18
+ except ImportError:
19
+ FLASH_ATTN_AVAILABLE = False
20
+
21
+ from transformers import (
22
+ PreTrainedModel,
23
+ PretrainedConfig,
24
+ DataCollatorForLanguageModeling,
25
+ )
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ MaskedLMOutput,
29
+ SequenceClassifierOutput,
30
+ )
31
+
32
+ from .rotary import precompute_freqs_cis, apply_rotary_emb
33
+ import torch.nn.functional as F
34
+
35
+
36
+ class SwiGLU(nn.Module):
37
+ def __init__(self, input_dim: int, hidden_dim: int = None, bias: bool = True):
38
+ super().__init__()
39
+ hidden_dim = hidden_dim or input_dim * 2
40
+ self.linear = nn.Linear(input_dim, hidden_dim * 2, bias=bias)
41
+ self.output_proj = nn.Linear(hidden_dim, input_dim, bias=bias)
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ x_proj = self.linear(x)
45
+ x1, x2 = x_proj.chunk(2, dim=-1)
46
+ x = x1 * F.silu(x2) # SwiGLU activation
47
+ return self.output_proj(x)
48
+
49
+
50
+ class DataCollatorWithPacking(DataCollatorForLanguageModeling):
51
+ def __init__(self, pack_sequences=False, **kwargs):
52
+ super().__init__(**kwargs)
53
+ self.pack_sequences = pack_sequences
54
+
55
+ def __call__(self, batch):
56
+ if self.pack_sequences:
57
+ # Add position_ids if not present
58
+ if "position_ids" not in batch[0]:
59
+ for item in batch:
60
+ item["position_ids"] = list(range(len(item["input_ids"])))
61
+
62
+ # Pack the sequences into a single list
63
+ input_ids_list = [item["input_ids"] for item in batch]
64
+ position_ids_list = [item["position_ids"] for item in batch]
65
+ seqlens = np.array([0] + [len(ids) for ids in input_ids_list])
66
+
67
+ packed_batch = {
68
+ "position_ids": np.concatenate(position_ids_list, axis=0),
69
+ "input_ids": np.concatenate(input_ids_list, axis=0),
70
+ "cu_seqlens": np.cumsum(seqlens),
71
+ "max_seqlen": max(seqlens),
72
+ }
73
+
74
+ batch = super().__call__([packed_batch])
75
+ batch["cu_seqlens"] = batch["cu_seqlens"].to(torch.int32).squeeze()
76
+ else:
77
+ batch = super().__call__(batch)
78
+ batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
79
+
80
+ return batch
81
+
82
+
83
+ class NeoBERTConfig(PretrainedConfig):
84
+ model_type = "neobert"
85
+
86
+ # All config parameters must have a default value.
87
+ def __init__(
88
+ self,
89
+ hidden_size: int = 768,
90
+ num_hidden_layers: int = 28,
91
+ num_attention_heads: int = 12,
92
+ intermediate_size: int = 3072,
93
+ embedding_init_range: float = 0.02,
94
+ decoder_init_range: float = 0.02,
95
+ norm_eps: float = 1e-06,
96
+ vocab_size: int = 30522,
97
+ pad_token_id: int = 0,
98
+ max_length: int = 1024,
99
+ **kwargs,
100
+ ):
101
+ super().__init__(**kwargs)
102
+
103
+ self.hidden_size = hidden_size
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.num_attention_heads = num_attention_heads
106
+ if hidden_size % num_attention_heads != 0:
107
+ raise ValueError("Hidden size must be divisible by the number of heads.")
108
+ self.dim_head = hidden_size // num_attention_heads
109
+ self.intermediate_size = intermediate_size
110
+ self.embedding_init_range = embedding_init_range
111
+ self.decoder_init_range = decoder_init_range
112
+ self.norm_eps = norm_eps
113
+ self.vocab_size = vocab_size
114
+ self.pad_token_id = pad_token_id
115
+ self.max_length = max_length
116
+ self.kwargs = kwargs
117
+
118
+
119
+ class EncoderBlock(nn.Module):
120
+ """Transformer encoder block."""
121
+
122
+ def __init__(self, config: NeoBERTConfig):
123
+ super().__init__()
124
+
125
+ self.config = config
126
+
127
+ # Attention
128
+ self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
129
+ self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)
130
+
131
+ # Feedforward network
132
+ multiple_of = 8
133
+ intermediate_size = int(2 * config.intermediate_size / 3)
134
+ intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
135
+ self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size)
136
+
137
+ # Layer norms
138
+ self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
139
+ self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
140
+
141
+ def forward(
142
+ self,
143
+ x: torch.Tensor,
144
+ attention_mask: torch.Tensor,
145
+ freqs_cis: torch.Tensor,
146
+ output_attentions: bool,
147
+ max_seqlen: int = None,
148
+ cu_seqlens: torch.Tensor = None,
149
+ ):
150
+ # Attention
151
+ attn_output, attn_weights = self._att_block(
152
+ self.attention_norm(x), attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens
153
+ )
154
+
155
+ # Residual
156
+ x = x + attn_output
157
+
158
+ # Feed-forward
159
+ x = x + self.ffn(self.ffn_norm(x))
160
+
161
+ return x, attn_weights
162
+
163
+ def _att_block(
164
+ self,
165
+ x: torch.Tensor,
166
+ attention_mask: torch.Tensor,
167
+ freqs_cis: torch.Tensor,
168
+ output_attentions: bool,
169
+ max_seqlen: int = None,
170
+ cu_seqlens: torch.Tensor = None,
171
+ ):
172
+ batch_size, seq_len, _ = x.shape
173
+
174
+ xq, xk, xv = self.qkv(x).view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3).chunk(3, axis=-1)
175
+
176
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
177
+
178
+ # Attn block
179
+ attn_weights = None
180
+
181
+ # Flash attention if the tensors are packed
182
+ if cu_seqlens is not None:
183
+ attn = flash_attn_varlen_func(
184
+ q=xq.squeeze(0),
185
+ k=xk.squeeze(0),
186
+ v=xv.squeeze(0),
187
+ cu_seqlens_q=cu_seqlens,
188
+ cu_seqlens_k=cu_seqlens,
189
+ max_seqlen_q=max_seqlen,
190
+ max_seqlen_k=max_seqlen,
191
+ dropout_p=0.0,
192
+ causal=False,
193
+ )
194
+ # Eager attention if attention weights are needed in the output
195
+ elif output_attentions:
196
+ attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
197
+ if attention_mask is not None:
198
+ attn_weights = attn_weights * attention_mask
199
+ attn_weights = attn_weights.softmax(-1)
200
+ attn = attn_weights @ xv.permute(0, 2, 1, 3)
201
+ attn = attn.transpose(1, 2)
202
+ # Fall back to SDPA otherwise
203
+ else:
204
+ attn = scaled_dot_product_attention(
205
+ query=xq.transpose(1, 2),
206
+ key=xk.transpose(1, 2),
207
+ value=xv.transpose(1, 2),
208
+ attn_mask=attention_mask.bool(),
209
+ dropout_p=0,
210
+ ).transpose(1, 2)
211
+
212
+ return self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.config.dim_head)), attn_weights
213
+
214
+
215
+ class NeoBERTPreTrainedModel(PreTrainedModel):
216
+ config_class = NeoBERTConfig
217
+ base_model_prefix = "model"
218
+ _supports_cache_class = True
219
+
220
+ def _init_weights(self, module):
221
+ if isinstance(module, nn.Linear):
222
+ module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
223
+ elif isinstance(module, nn.Embedding):
224
+ module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)
225
+
226
+
227
+ class NeoBERT(NeoBERTPreTrainedModel):
228
+ config_class = NeoBERTConfig
229
+
230
+ def __init__(self, config: NeoBERTConfig):
231
+ super().__init__(config)
232
+ self.output_hidden_states = True
233
+
234
+ self.config = config
235
+
236
+ self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
237
+
238
+ # Ensures freqs_cis is moved to the same devices as the model. Non-persistent buffers are not saved in the state_dict.
239
+ freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
240
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
241
+
242
+ self.transformer_encoder = nn.ModuleList()
243
+ for _ in range(config.num_hidden_layers):
244
+ self.transformer_encoder.append(EncoderBlock(config))
245
+
246
+ self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
247
+
248
+ # Initialize weights and apply final processing
249
+ self.post_init()
250
+
251
+ def forward(
252
+ self,
253
+ input_ids: Optional[torch.Tensor] = None,
254
+ position_ids: torch.Tensor = None,
255
+ max_seqlen: int = None,
256
+ cu_seqlens: torch.Tensor = None,
257
+ attention_mask: torch.Tensor = None,
258
+ inputs_embeds: Optional[torch.Tensor] = None,
259
+ output_hidden_states: bool = False,
260
+ output_attentions: bool = False,
261
+ **kwargs,
262
+ ):
263
+ # Initialize
264
+ hidden_states, attentions = [], []
265
+
266
+ if (input_ids is None) ^ (inputs_embeds is not None):
267
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
268
+
269
+ # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
270
+ if attention_mask is not None:
271
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
272
+
273
+ # Checks to be done if inputs are packed sequences
274
+ if cu_seqlens is not None:
275
+ assert (
276
+ FLASH_ATTN_AVAILABLE
277
+ ), "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
278
+ assert not output_attentions, "Output attentions is not supported when sequences are packed."
279
+ assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
280
+ assert (input_ids if input_ids is not None else inputs_embeds).shape[
281
+ 0
282
+ ] == 1, "Cumulative sequence lengths are provided but inputs are not packed."
283
+ assert (
284
+ input_ids if input_ids is not None else inputs_embeds
285
+ ).is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
286
+
287
+ # RoPE
288
+ freqs_cis = (
289
+ self.freqs_cis[position_ids]
290
+ if position_ids is not None
291
+ else self.freqs_cis[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
292
+ )
293
+
294
+ # Embedding
295
+ if input_ids is not None:
296
+ input_ids = input_ids.long() # Ensure correct dtype
297
+ x = self.encoder(input_ids)
298
+ else:
299
+ x = inputs_embeds
300
+
301
+ # ⬇️ ADD THIS LINE to capture the embedding output
302
+ if output_hidden_states:
303
+ hidden_states.append(x)
304
+
305
+ # Transformer encoder
306
+ for layer in self.transformer_encoder:
307
+ x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
308
+ if output_hidden_states:
309
+ hidden_states.append(x)
310
+ if output_attentions:
311
+ attentions.append(attn)
312
+
313
+ # Final normalization layer
314
+ x = self.layer_norm(x)
315
+
316
+ # Return the output of the last hidden layer
317
+ return BaseModelOutput(
318
+ last_hidden_state=x,
319
+ hidden_states=hidden_states,
320
+ attentions=attentions if output_attentions else None,
321
+ )
322
+
323
+
324
+ # class NeoBERTLMHead(NeoBERTPreTrainedModel):
325
+ # config_class = NeoBERTConfig
326
+
327
+ # def __init__(self, config: NeoBERTConfig):
328
+ # super().__init__(config)
329
+
330
+ # self.config = config
331
+
332
+ # self.model = NeoBERT(config)
333
+ # self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
334
+
335
+ # self.post_init()
336
+
337
+ # def forward(
338
+ # self,
339
+ # input_ids: torch.Tensor,
340
+ # position_ids: torch.Tensor = None,
341
+ # max_seqlen: int = None,
342
+ # cu_seqlens: torch.Tensor = None,
343
+ # attention_mask: torch.Tensor = None,
344
+ # output_hidden_states: bool = False,
345
+ # output_attentions: bool = False,
346
+ # **kwargs,
347
+ # ):
348
+
349
+ # output = self.model.forward(
350
+ # input_ids=input_ids,
351
+ # position_ids=position_ids,
352
+ # max_seqlen=max_seqlen,
353
+ # cu_seqlens=cu_seqlens,
354
+ # attention_mask=attention_mask,
355
+ # output_hidden_states=output_hidden_states,
356
+ # output_attentions=output_attentions,
357
+ # )
358
+ # logits = self.decoder(output.last_hidden_state)
359
+
360
+ # return MaskedLMOutput(
361
+ # hidden_states=output.hidden_states if output_hidden_states else None,
362
+ # attentions=output.attentions if output_attentions else None,
363
+ # logits=logits,
364
+ # )
365
+
366
+
367
+ import torch.nn.functional as F
368
+ from transformers.modeling_outputs import MaskedLMOutput
369
+
370
+ class NeoBERTLMHead(NeoBERTPreTrainedModel):
371
+ config_class = NeoBERTConfig
372
+
373
+ def __init__(self, config: NeoBERTConfig):
374
+ super().__init__(config)
375
+
376
+ self.config = config
377
+
378
+ self.model = NeoBERT(config)
379
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
380
+ self.decoder.weight = self.model.encoder.weight
381
+ self.post_init()
382
+
383
+ def forward(
384
+ self,
385
+ input_ids: torch.Tensor,
386
+ position_ids: torch.Tensor = None,
387
+ max_seqlen: int = None,
388
+ cu_seqlens: torch.Tensor = None,
389
+ attention_mask: torch.Tensor = None,
390
+ labels: torch.Tensor = None,
391
+ output_hidden_states: bool = False,
392
+ output_attentions: bool = False,
393
+ **kwargs,
394
+ ):
395
+ output = self.model.forward(
396
+ input_ids=input_ids,
397
+ position_ids=position_ids,
398
+ max_seqlen=max_seqlen,
399
+ cu_seqlens=cu_seqlens,
400
+ attention_mask=attention_mask,
401
+ output_hidden_states=output_hidden_states,
402
+ output_attentions=output_attentions,
403
+ )
404
+ logits = self.decoder(output.last_hidden_state)
405
+
406
+ loss = None
407
+ if labels is not None:
408
+ # Shape: (batch, seq_len, vocab_size) => (batch * seq_len, vocab_size)
409
+ # labels: (batch, seq_len) => (batch * seq_len)
410
+ loss = F.cross_entropy(
411
+ logits.view(-1, logits.size(-1)),
412
+ labels.view(-1),
413
+ ignore_index=-100 # this matches what your metrics are using
414
+ )
415
+
416
+ return MaskedLMOutput(
417
+ loss=loss,
418
+ logits=logits,
419
+ hidden_states=output.hidden_states if output_hidden_states else None,
420
+ attentions=output.attentions if output_attentions else None,
421
+ )
422
+
423
+ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
424
+ config_class = NeoBERTConfig
425
+
426
+ def __init__(self, config: NeoBERTConfig):
427
+ super().__init__(config)
428
+
429
+ self.config = config
430
+
431
+ self.num_labels = getattr(config, "num_labels", 2)
432
+ self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
433
+ self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
434
+
435
+ self.model = NeoBERT(config)
436
+
437
+ self.dense = nn.Linear(self.config.hidden_size, self.config.hidden_size)
438
+ self.dropout = nn.Dropout(self.classifier_dropout)
439
+ self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
440
+
441
+ self.post_init()
442
+
443
+ def _init_weights(self, module):
444
+ if isinstance(module, nn.Linear):
445
+ module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
446
+ if module.bias is not None:
447
+ module.bias.data.zero_()
448
+
449
+ def forward(
450
+ self,
451
+ input_ids: Optional[torch.Tensor] = None,
452
+ position_ids: torch.Tensor = None,
453
+ max_seqlen: int = None,
454
+ cu_seqlens: torch.Tensor = None,
455
+ attention_mask: torch.Tensor = None,
456
+ output_hidden_states: bool = False,
457
+ output_attentions: bool = False,
458
+ labels: Optional[torch.Tensor] = None,
459
+ return_dict: Optional[bool] = None,
460
+ ):
461
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
462
+
463
+ output = self.model.forward(
464
+ input_ids=input_ids,
465
+ position_ids=position_ids,
466
+ max_seqlen=max_seqlen,
467
+ cu_seqlens=cu_seqlens,
468
+ attention_mask=attention_mask,
469
+ output_hidden_states=output_hidden_states,
470
+ output_attentions=output_attentions,
471
+ )
472
+ hidden_states = output.last_hidden_state
473
+
474
+ x = hidden_states[:, 0, :]
475
+ x = self.dropout(x)
476
+ x = self.dense(x)
477
+ x = torch.tanh(x)
478
+ x = self.dropout(x)
479
+
480
+ logits = self.classifier(x)
481
+
482
+ loss = None
483
+ if labels is not None:
484
+ if self.config.problem_type is None:
485
+ if self.num_labels == 1:
486
+ self.config.problem_type = "regression"
487
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
488
+ self.config.problem_type = "single_label_classification"
489
+ else:
490
+ self.config.problem_type = "multi_label_classification"
491
+
492
+ if self.config.problem_type == "regression":
493
+ loss_fct = MSELoss()
494
+ if self.num_labels == 1:
495
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
496
+ else:
497
+ loss = loss_fct(logits, labels)
498
+ elif self.config.problem_type == "single_label_classification":
499
+ loss_fct = CrossEntropyLoss()
500
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
501
+ elif self.config.problem_type == "multi_label_classification":
502
+ loss_fct = BCEWithLogitsLoss()
503
+ loss = loss_fct(logits, labels)
504
+
505
+ if not return_dict:
506
+ result = (logits,)
507
+ return ((loss,) + result) if loss is not None else result
508
+
509
+ return SequenceClassifierOutput(
510
+ loss=loss,
511
+ logits=logits,
512
+ hidden_states=output.hidden_states if output_hidden_states else None,
513
+ attentions=output.attentions if output_attentions else None,
514
+ )