Recag committed
Commit da5cf53 · 1 Parent(s): 053a765

Update model.py

Files changed (1)
  1. model.py +14 -25
model.py CHANGED
@@ -1,5 +1,5 @@
  # coding=utf-8
- # Copyright 2023 BharatTech Tech Ecosystem Pvt. All rights reserved.
+ # Copyright 2023 BharatTech Tech Ecosystem Pvt. Ltd. All rights reserved.

  """ PyTorch Bharatai model."""
  import math
@@ -16,7 +16,7 @@ from transformers.activations import ACT2FN
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
  from transformers.modeling_utils import PreTrainedModel
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
  from transformers.utils import (
      add_start_docstrings,
      add_start_docstrings_to_model_forward,
@@ -36,8 +36,8 @@ if is_flash_attn_2_available():
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
  # It means that the function will not be traced through and simply appear as a node in the graph.
  if is_torch_fx_available():
-
-     import torch.fx
+     if not is_torch_greater_or_equal_than_1_13:
+         import torch.fx

      _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)

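The hunk above only matters for symbolic tracing: `torch.fx.wrap` registers `_prepare_4d_causal_attention_mask` as a leaf, so FX records a single call node instead of descending into its data-dependent mask logic, and the new guard skips the explicit `import torch.fx` on newer torch versions where it is presumably already available after `import torch`. A minimal, self-contained sketch of that leaf-function pattern (illustrative names, not code from this file):

import torch
import torch.fx


def make_causal_mask(seq_len):
    # Plain Python branching on a runtime value: the kind of data-dependent
    # control flow symbolic tracing cannot follow without a leaf wrapper.
    if seq_len <= 1:
        return torch.ones(1, 1, dtype=torch.bool)
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


# Register the helper as a leaf, mirroring the wrap call in the hunk above:
# tracing emits one call_function node instead of tracing into the body.
make_causal_mask = torch.fx.wrap(make_causal_mask)


class TinyModel(torch.nn.Module):
    def forward(self, x):
        mask = make_causal_mask(x.size(-1))
        return x * mask


traced = torch.fx.symbolic_trace(TinyModel())
print(traced.graph)               # make_causal_mask appears as a single node
print(traced(torch.randn(4, 4)))  # the traced module still runs eagerly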
@@ -98,7 +98,7 @@ ALL_LAYERNORM_LAYERS.append(BharataiRMSNorm)


  class BharataiRotaryEmbedding(nn.Module):
-     def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
          super().__init__()

          self.dim = dim
@@ -136,7 +136,7 @@ class BharataiRotaryEmbedding(nn.Module):
  class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
      """BharataiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

-     def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None, scaling_factor=1.0):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
          self.scaling_factor = scaling_factor
          super().__init__(dim, max_position_embeddings, base, device)
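With the default window reduced from 16384 to 2048, longer contexts lean on the scaling variants. Linear scaling, as the docstring credits to /u/kaiokendev, divides position indices by `scaling_factor` before the rotary frequencies are computed. A rough standalone sketch of the idea (illustrative values, not this class's exact code):

import torch

dim, base, scaling_factor = 64, 10000.0, 4.0
max_position_embeddings = 2048                       # the new default window
seq_len = int(scaling_factor * max_position_embeddings)  # 8192 target context

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# The linear-scaling step: positions are divided by scaling_factor, so
# 0..8191 maps onto the 0..2047.75 range the base embedding was built for.
t = torch.arange(seq_len).float() / scaling_factor
freqs = torch.outer(t, inv_freq)                     # (seq_len, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)                                     # torch.Size([8192, 64])

The longer sequence reuses the phase range the base embedding was trained on, at the cost of coarser position resolution.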
@@ -155,7 +155,7 @@ class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
  class BharataiDynamicNTKScalingRotaryEmbedding(BharataiRotaryEmbedding):
      """BharataiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

-     def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None, scaling_factor=1.0):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
          self.scaling_factor = scaling_factor
          super().__init__(dim, max_position_embeddings, base, device)
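Dynamic NTK scaling takes the other route: rather than compressing positions, it enlarges the rotary `base` once the running sequence length exceeds `max_position_embeddings`, so short sequences are unaffected. A hedged sketch using the formula commonly used for this variant (illustrative values, not this file's code):

import torch

dim, base = 64, 10000.0
max_position_embeddings, scaling_factor = 2048, 1.0
seq_len = 4096                          # current sequence, past the window

# NTK-style stretch: grow the base so the lowest rotary frequencies slow
# down enough to cover the longer sequence; below the window, base stays put.
if seq_len > max_position_embeddings:
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
cos, sin = freqs.cos(), freqs.sin()
print(round(base))   # the base roughly doubles (about 20450) for this 2x extension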
@@ -896,24 +896,13 @@ class BharataiModel(BharataiPreTrainedModel):
      past_key_value = past_key_values[idx] if past_key_values is not None else None

      if self.gradient_checkpointing and self.training:
-         layer_outputs = self._gradient_checkpointing_func(
-             decoder_layer.__call__,
-             hidden_states,
-             attention_mask,
-             position_ids,
-             past_key_value,
-             output_attentions,
-             use_cache,
-         )
+         if output_attentions:
+             layer_outputs = self._gradient_checkpointing_func(decoder_layer.__call__, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
+         else:
+             layer_outputs = self._gradient_checkpointing_func(decoder_layer.__call__, hidden_states, attention_mask, position_ids, past_key_value, None, use_cache)
      else:
-         layer_outputs = decoder_layer(
-             hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-         )
+         layer_outputs = decoder_layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache)
+

      hidden_states = layer_outputs[0]
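In the checkpointed branch, `_gradient_checkpointing_func` (in recent transformers releases typically a thin wrapper around `torch.utils.checkpoint.checkpoint`) forwards positional arguments only, which is why the layer is invoked through `decoder_layer.__call__` with a flat argument list while the non-checkpointed branch keeps keyword arguments. A small self-contained sketch of that call shape (toy module, illustrative names):

import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        out = torch.relu(self.proj(hidden_states))
        if attention_mask is not None:
            out = out * attention_mask.unsqueeze(-1)
        return (out, None) if output_attentions else (out,)


layer = ToyLayer()
x = torch.randn(2, 4, 16, requires_grad=True)
mask = torch.ones(2, 4)

# checkpoint() takes the callable plus a flat, positional argument list, so
# flags such as output_attentions are threaded through by position, as in the
# checkpointed branch of the hunk above.
outputs = checkpoint(layer.__call__, x, mask, False, use_reentrant=False)
outputs[0].sum().backward()   # activations are recomputed during backward
print(x.grad.shape)           # torch.Size([2, 4, 16])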
@@ -1218,4 +1207,4 @@ class BharataiForSequenceClassification(BharataiPreTrainedModel):
      past_key_values=transformer_outputs.past_key_values,
      hidden_states=transformer_outputs.hidden_states,
      attentions=transformer_outputs.attentions,
-     )
+     )
 