Update model.py
model.py
CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2023 BharatTech Tech Ecosystem Pvt.
+# Copyright 2023 BharatTech Tech Ecosystem Pvt. Ltd.All rights reserved.
 
 """ PyTorch Bharatai model."""
 import math
@@ -16,7 +16,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -36,8 +36,8 @@ if is_flash_attn_2_available():
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
 if is_torch_fx_available():
-
-
+    if not is_torch_greater_or_equal_than_1_13:
+        import torch.fx
 
     _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
 
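Note on this hunk: the new guard imports `torch.fx` explicitly on torch versions before 1.13 so that `torch.fx.wrap` is available when the mask helper is registered as a leaf, mirroring the pattern in upstream transformers. As a minimal, self-contained sketch of what wrapping does (the `double_plus_one` helper and `Tiny` module below are illustrative, not part of model.py), symbolic tracing records a wrapped function as a single graph node instead of tracing into its body:

import torch
import torch.fx


def double_plus_one(x):
    # Illustrative helper, not part of model.py.
    return x * 2 + 1


# Register the helper as a leaf: symbolic tracing emits a single
# call_function node for it instead of tracing through its body.
double_plus_one = torch.fx.wrap(double_plus_one)


class Tiny(torch.nn.Module):
    def forward(self, x):
        return double_plus_one(x)


print(torch.fx.symbolic_trace(Tiny()).graph)  # shows one call to double_plus_one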
@@ -98,7 +98,7 @@ ALL_LAYERNORM_LAYERS.append(BharataiRMSNorm)
 
 
 class BharataiRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
         self.dim = dim
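Note on this hunk: only the constructor signature is touched, and the class body is not shown in the diff. In the Llama-style code this file follows, such a constructor typically precomputes inverse frequencies from `dim` and `base` and derives cos/sin caches from them. The sketch below shows that conventional pattern under that assumption; it is not the file's actual implementation:

import torch
from torch import nn


class RotaryEmbeddingSketch(nn.Module):
    """Illustrative stand-in for BharataiRotaryEmbedding, not the file's actual body."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies: 1 / base^(2i/dim) for i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len):
        # Outer product of positions and inverse frequencies gives the rotation angles.
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)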
@@ -136,7 +136,7 @@ class BharataiRotaryEmbedding(nn.Module):
 class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
     """BharataiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
-    def __init__(self, dim, max_position_embeddings=
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
 
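Note on this hunk: the `scaling_factor` added to the signature is how linear position scaling is parameterized. In the upstream Llama implementation this class credits, the position indices are divided by `scaling_factor` before the rotation angles are computed; a small sketch of that step with hypothetical values:

import torch

# Hypothetical values for illustration only.
dim, base, scaling_factor, seq_len = 64, 10000, 2.0, 8

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seq_len, dtype=torch.float32)
t = t / scaling_factor  # linear scaling: stretch positions onto the trained range
freqs = torch.outer(t, inv_freq)  # angles used to build the cos/sin caches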
@@ -155,7 +155,7 @@ class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
 class BharataiDynamicNTKScalingRotaryEmbedding(BharataiRotaryEmbedding):
     """BharataiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
-    def __init__(self, dim, max_position_embeddings=
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
 
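Note on this hunk: for Dynamic NTK scaling, the upstream pattern recomputes the rotary `base` from the current sequence length once it exceeds `max_position_embeddings`, so the low-frequency components stretch to cover the longer context. A hedged sketch with hypothetical values (not this file's exact code):

import torch

# Hypothetical values for illustration only.
dim, base, scaling_factor, max_position_embeddings = 64, 10000, 1.0, 2048
seq_len = 4096  # longer than the trained context

if seq_len > max_position_embeddings:
    # Dynamic NTK: grow the rotary base so the rotations still span the longer context.
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))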
@@ -896,24 +896,13 @@ class BharataiModel(BharataiPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-
-                    decoder_layer.__call__,
-
-                    attention_mask,
-                    position_ids,
-                    past_key_value,
-                    output_attentions,
-                    use_cache,
-                )
+                if output_attentions:
+                    layer_outputs = self._gradient_checkpointing_func(decoder_layer.__call__, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
+                else:
+                    layer_outputs = self._gradient_checkpointing_func(decoder_layer.__call__, hidden_states, attention_mask, position_ids, past_key_value, None, use_cache)
             else:
-                layer_outputs = decoder_layer(
-
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
+                layer_outputs = decoder_layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache)
+
 
             hidden_states = layer_outputs[0]
 
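Note on this hunk: `self._gradient_checkpointing_func` is the callable recent transformers releases install when gradient checkpointing is enabled, and it essentially forwards to `torch.utils.checkpoint.checkpoint`. The sketch below reproduces the control flow of the new code with a toy layer; `ToyLayer` and `gradient_checkpointing_func` are stand-ins, not names from model.py, and the mask/position/cache arguments are omitted for brevity:

import functools

import torch
import torch.utils.checkpoint
from torch import nn


class ToyLayer(nn.Module):
    # Stand-in for a decoder layer: returns a tuple like (hidden_states, attn_weights).
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states, output_attentions=None, use_cache=False):
        return (self.proj(hidden_states), None)


# Roughly what transformers installs as self._gradient_checkpointing_func.
gradient_checkpointing_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

layer = ToyLayer()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
training, gradient_checkpointing, output_attentions, use_cache = True, True, False, False

if gradient_checkpointing and training:
    # Activations inside the layer are recomputed during backward instead of stored.
    if output_attentions:
        layer_outputs = gradient_checkpointing_func(layer.__call__, hidden_states, output_attentions, use_cache)
    else:
        layer_outputs = gradient_checkpointing_func(layer.__call__, hidden_states, None, use_cache)
else:
    layer_outputs = layer(hidden_states, output_attentions=output_attentions, use_cache=use_cache)

hidden_states = layer_outputs[0]
hidden_states.sum().backward()  # gradients flow through the checkpointed layer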
@@ -1218,4 +1207,4 @@ class BharataiForSequenceClassification(BharataiPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )