MariaFjodorowa committed
Commit e78f45d · verified · 1 Parent(s): 6287598

fix NaNs and output format

Files changed (1)
  1. modeling_gptbert.py +138 -146
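
Note: the commit does two things. _init_weights now initializes the custom MultiCastedLinearOrthoIn weight list, touches LayerNorm affine parameters only when they actually exist (all norms here use elementwise_affine=False), and resets per-feature scale parameters (the NaN fix). The partially broken `if not return_dict:` tuple paths in the task heads, some of which referenced undefined variables, are removed, so every head consistently returns its ModelOutput dataclass (the output-format fix). A minimal usage sketch under stated assumptions — the checkpoint id below is a placeholder and the repo's auto_map is assumed to expose GptBertForMaskedLM:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Placeholder checkpoint id; trust_remote_code=True is required to load modeling_gptbert.py.
checkpoint = "<namespace>/<gpt-bert-checkpoint>"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(checkpoint, trust_remote_code=True)

inputs = tokenizer("Hello world.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# After this commit the head always returns a MaskedLMOutput, so read predictions from .logits.
print(type(outputs).__name__, outputs.logits.shape)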
modeling_gptbert.py CHANGED
@@ -25,7 +25,6 @@ from typing import TYPE_CHECKING, Optional, Union, Tuple, List
 
 logger = logging.get_logger(__name__)
 
-
 # Workaround for transformers < 4.36.0 check_imports issue
 # See: https://github.com/huggingface/transformers/issues/28459
 try:
@@ -92,7 +91,8 @@ class CastedLinearIn(nn.Linear):
         self.scale = nn.Parameter(torch.ones(in_features))
 
     def forward(self, x):
-        return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
+        return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x),
+                        bias=self.bias.type_as(x) if self.bias is not None else None)
 
 
 class MultiCastedLinearOrthoIn(nn.Module):
@@ -114,7 +114,9 @@ class MultiCastedLinearOrthoIn(nn.Module):
         self.scale = nn.Parameter(torch.ones(in_features))
 
     def forward(self, x):
-        return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
+        return F.linear(x, (
+                torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x),
+                        bias=self.bias.type_as(x) if self.bias is not None else None)
 
 
 class GeGLU(nn.Module):
@@ -128,7 +130,8 @@ class Embedding(nn.Module):
         super().__init__()
 
         self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
+        self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False,
+                                      bias=False)
         self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
         self.dropout = nn.Dropout(config.embedding_dropout)
 
@@ -179,7 +182,9 @@ class Classifier(nn.Module):
 
 
 # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
-def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
+def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor,
+                            max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float,
+                            deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
     qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
 
     convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
@@ -223,7 +228,8 @@ class ApplyRotaryEmbUnpad(torch.autograd.Function):
         # we get the same tensor
         # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
         qk = qkv[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
+        apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False,
+                     inplace=True)
 
         ctx.save_for_backward(cos, sin, cu_seqlens)
         ctx.max_seqlen = max_seqlen
@@ -263,7 +269,8 @@ class UnpaddedRotaryEmbedding(RotaryEmbedding):
         super().__init__(dim=dim, base=base, device=None, interleaved=False)
         self.max_seqlen = max_seqlen
 
-    def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[
+            torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
 
@@ -351,11 +358,12 @@ class SelfAttention(nn.Module):
 
         self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
         self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
-        self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
+        self.out_proj = CastedLinearIn(self.d_v * self.num_attention_heads, self.hidden_size, bias=False)
 
         self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
         self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
-        self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps,
+                                       elementwise_affine=False)
         self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
         self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
         self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
@@ -368,12 +376,13 @@ class SelfAttention(nn.Module):
 
         # Initialize rotary embeddings based on whether FlashAttention is available
         if flash_attn_varlen_qkvpacked_func is not None:
-            self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
+            self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta,
+                                                          max_seqlen=config.max_sequence_length)
         else:
            self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
 
         self.scale = 1.0 / math.sqrt(self.d_qk)
-        #self.lambdas = nn.Parameter(torch.tensor([0.5]))
+        # self.lambdas = nn.Parameter(torch.tensor([0.5]))
 
         self.sequence_length = config.max_sequence_length
         self.is_causal = config.is_decoder
@@ -392,7 +401,8 @@ class SelfAttention(nn.Module):
         mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
         return mask.view(1, 1, query_length, key_length)
 
-    def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                            padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Standard attention computation with masking."""
         batch_size, _, query_length, _ = query.size()
         _, _, key_length, _ = key.size()
@@ -405,7 +415,8 @@ class SelfAttention(nn.Module):
         else:
             attention_mask = window_mask
 
-        attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale  # shape: [B*H, Q_T, K_T]
+        attention_scores = torch.bmm(query.flatten(0, 1),
+                                     key.transpose(-1, -2).flatten(0, 1)) * self.scale  # shape: [B*H, Q_T, K_T]
         attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
 
         attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
@@ -505,16 +516,17 @@ class SelfAttention(nn.Module):
         return output, v1
 
 
-class FeedForward(nn.Module):
+class FeedForward(nn.Module):
     def __init__(self, config: GptBertConfig):
         super().__init__()
         self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
-        self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
+        self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size,
+                                                [config.intermediate_size, config.intermediate_size], bias=False)
         self.activation = GeGLU()
         self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
         self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
         self.dropout = nn.Dropout(config.hidden_dropout)
-
+
     def forward(self, x: torch.Tensor):
         x = self.pre_norm(x.float()).type_as(x)
         x = self.up_proj(x)
@@ -559,10 +571,12 @@ class Layer(nn.Module):
         qk_layer = (lambdas_qk[0] * hidden_layer) + (lambdas_qk[1] * embeddings)
         attention_output, v1 = self.attention(v_layer, qk_layer, v1, padding_info)
 
-        mlp_layer = (lambdas_mlp[0] * attention_output) + (lambdas_mlp[1] * hidden_layer) + (lambdas_mlp[2] * embeddings)
+        mlp_layer = (lambdas_mlp[0] * attention_output) + (lambdas_mlp[1] * hidden_layer) + (
+                lambdas_mlp[2] * embeddings)
         mlp_layer = self.mlp(mlp_layer)
 
-        output = (lambdas_out[0] * mlp_layer) + (lambdas_out[1] * attention_output) + (lambdas_out[2] * hidden_layer) + (lambdas_out[3] * embeddings)
+        output = (lambdas_out[0] * mlp_layer) + (lambdas_out[1] * attention_output) + (
+                lambdas_out[2] * hidden_layer) + (lambdas_out[3] * embeddings)
 
         return output, v1
 
@@ -580,14 +594,16 @@ class Encoder(nn.Module):
             else:
                 layer.set_window_length(config.local_window_length)
 
-    def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
+    def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False,
+                checkpoint_activations=False):
         hidden_layers = [hidden_layer] if output_hidden_states else None
         v1 = None
         embeddings = hidden_layer
 
        for layer in self.layers:
            if checkpoint_activations:
-                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
+                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info,
+                                                                     use_reentrant=True)
            else:
                hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
 
@@ -611,15 +627,19 @@ class GptBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         std = math.sqrt(2.0 / (5.0 * self.hidden_size))
 
-        if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
-            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+        if isinstance(module, MultiCastedLinearOrthoIn):
+            for weight in module.weights:
+                nn.init.trunc_normal_(weight.data, mean=0.0, std=std, a=-2 * std, b=2 * std)
+        elif isinstance(module, (nn.Linear, nn.Embedding)):
+            nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2 * std, b=2 * std)
         elif isinstance(module, nn.LayerNorm):
+            if module.weight is not None:
+                module.weight.data.fill_(1.0)
+
+        if hasattr(module, 'bias') and module.bias is not None:
             module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+        if hasattr(module, 'scale') and isinstance(module.scale, nn.Parameter):
+            module.scale.data.fill_(1.0)
 
 
 class GptBertModel(GptBertPreTrainedModel):
@@ -645,10 +665,10 @@ class GptBertModel(GptBertPreTrainedModel):
         self.embedding.word_embedding = value
 
     def get_contextualized_embeddings(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -697,24 +717,26 @@ class GptBertModel(GptBertPreTrainedModel):
         if flash_attn_varlen_qkvpacked_func is not None:
             last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
             if output_hidden_states:
-                contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
+                contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in
+                                             contextualized_embeddings]
             else:
                 contextualized_embeddings = None
 
         return last_layer, contextualized_embeddings
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **kwargs
     ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask,
+                                                                                         output_hidden_states)
 
         if not return_dict:
             return (
@@ -741,17 +763,18 @@ class GptBertForMaskedLM(GptBertModel):
         self.classifier.emb2vocab.weight = new_embeddings
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
-        **kwargs
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            labels: Optional[torch.LongTensor] = None,
+            **kwargs
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask,
+                                                                                         output_hidden_states)
         subword_prediction = self.classifier(sequence_output)
         subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
 
@@ -761,7 +784,8 @@ class GptBertForMaskedLM(GptBertModel):
             subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
             masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
 
-        bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
+        bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype,
+                                 device=subword_prediction.device)
         bos_logits[:, :, self.config.bos_token_id] = 1.0
         subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
 
@@ -808,26 +832,27 @@ class GptBertForCausalLM(GptBertModel):
         return True
 
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        past_key_values: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            token_type_ids: Optional[torch.Tensor] = None,
+            past_key_values: Optional[torch.Tensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            cache_position: Optional[torch.LongTensor] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None
     ) -> Union[Tuple, CausalLMOutput]:
 
         assert inputs_embeds is None, "inputs_embeds is not supported for now"
         assert past_key_values is None, "past_key_values is not supported for now"
         assert not use_cache, "use_cache is not supported for now"
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask,
+                                                                                         output_hidden_states)
         subword_prediction = self.classifier(sequence_output)
         subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
 
@@ -837,13 +862,6 @@ class GptBertForCausalLM(GptBertModel):
             subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
             causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
 
-        if not return_dict:
-            output = (
-                subword_prediction,
-                *([contextualized_embeddings] if output_hidden_states else [])
-            )
-            return ((causal_lm_loss,) + output) if masked_lm_loss is not None else output
-
         return CausalLMOutput(
             loss=causal_lm_loss,
             logits=subword_prediction,
@@ -851,23 +869,23 @@ class GptBertForCausalLM(GptBertModel):
         )
 
     def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.Tensor,
-        past_key_values: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        use_cache: bool = True,
-        num_logits_to_keep: Optional[int] = None,
-        **kwargs,
+            self,
+            input_ids: torch.Tensor,
+            past_key_values: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+            cache_position: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            use_cache: bool = True,
+            num_logits_to_keep: Optional[int] = None,
+            **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
         if past_key_values is not None:
             if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
+                input_ids = input_ids[:, -cache_position.shape[0]:]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
 
@@ -876,7 +894,7 @@ class GptBertForCausalLM(GptBertModel):
 
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = position_ids[:, -input_ids.shape[1]:]
 
                 # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                 position_ids = position_ids.clone(memory_format=torch.contiguous_format)
@@ -914,17 +932,18 @@ class GptBertForSequenceClassification(GptBertModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
-        **kwargs
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            labels: Optional[torch.LongTensor] = None,
+            **kwargs
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask,
+                                                                                         output_hidden_states)
         logits = self.classifier(sequence_output[:, 0, :])
 
         loss = None
@@ -950,13 +969,6 @@ class GptBertForSequenceClassification(GptBertModel):
                 loss_fct = nn.BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
 
-        if not return_dict:
-            output = (
-                logits,
-                *([contextualized_embeddings] if output_hidden_states else [])
-            )
-            return ((loss,) + output) if loss is not None else output
-
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
@@ -976,17 +988,18 @@ class GptBertForTokenClassification(GptBertModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
-        **kwargs
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            labels: Optional[torch.LongTensor] = None,
+            **kwargs
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask,
+                                                                                         output_hidden_states)
         logits = self.classifier(sequence_output)
 
         loss = None
@@ -994,19 +1007,10 @@ class GptBertForTokenClassification(GptBertModel):
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
-        if not return_dict:
-            output = (
-                logits,
-                *([contextualized_embeddings] if output_hidden_states else []),
-                *([attention_probs] if output_attentions else [])
-            )
-            return ((loss,) + output) if loss is not None else output
-
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
             hidden_states=contextualized_embeddings if output_hidden_states else None,
-            attentions=attention_probs if output_attentions else None
         )
 
 
@@ -1022,18 +1026,19 @@ class GptBertForQuestionAnswering(GptBertModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        start_positions: Optional[torch.Tensor] = None,
-        end_positions: Optional[torch.Tensor] = None,
-        **kwargs
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            start_positions: Optional[torch.Tensor] = None,
+            end_positions: Optional[torch.Tensor] = None,
+            **kwargs
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask,
+                                                                                         output_hidden_states)
         logits = self.classifier(sequence_output)
 
         start_logits, end_logits = logits.split(1, dim=-1)
@@ -1058,14 +1063,6 @@ class GptBertForQuestionAnswering(GptBertModel):
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
 
-        if not return_dict:
-            output = (
-                start_logits,
-                end_logits,
-                *([contextualized_embeddings] if output_hidden_states else [])
-            )
-            return ((total_loss,) + output) if total_loss is not None else output
-
         return QuestionAnsweringModelOutput(
             loss=total_loss,
             start_logits=start_logits,
@@ -1086,13 +1083,13 @@ class GptBertForMultipleChoice(GptBertModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs
+            self,
+            input_ids: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            labels: Optional[torch.Tensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **kwargs
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1]
@@ -1100,7 +1097,9 @@ class GptBertForMultipleChoice(GptBertModel):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
 
-        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
+        sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids,
+                                                                                         flat_attention_mask,
+                                                                                         output_hidden_states)
         logits = self.classifier(sequence_output)
         reshaped_logits = logits.view(-1, num_choices)
 
@@ -1109,13 +1108,6 @@ class GptBertForMultipleChoice(GptBertModel):
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
 
-        if not return_dict:
-            output = (
-                reshaped_logits,
-                *([contextualized_embeddings] if output_hidden_states else [])
-            )
-            return ((loss,) + output) if loss is not None else output
-
         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,
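
Note on the NaN fix: every norm in this file is built with elementwise_affine=False, so it carries no affine weight (and often no bias), and the projections of MultiCastedLinearOrthoIn live in a weights list rather than in a plain nn.Linear. The old _init_weights therefore either touched parameters that do not exist or skipped parameters that were never initialized. A standalone sketch of the two PyTorch facts behind this (hypothetical snippet, not part of the repository):

import torch
import torch.nn as nn

# A LayerNorm without affine parameters registers weight and bias as None,
# so the old `module.weight.data.fill_(1.0)` / `module.bias.data.zero_()` cannot be applied to it.
norm = nn.LayerNorm(8, elementwise_affine=False)
print(norm.weight is None, norm.bias is None)  # True True

# A parameter allocated with torch.empty() is uninitialized memory; if _init_weights never
# reaches it, it can hold arbitrary values, which is presumably where the NaNs came from
# for the custom projection modules.
w = nn.Parameter(torch.empty(4, 4))
print(w.data)  # whatever happened to be in memory, not a valid initialization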