fixing import for modeling_gemma.py
Browse files- modeling_gemma.py +10 -31
modeling_gemma.py
CHANGED
|
@@ -27,19 +27,19 @@ import torch.utils.checkpoint
|
|
| 27 |
from torch import nn
|
| 28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
|
| 30 |
-
from
|
| 31 |
-
from
|
| 32 |
-
from
|
| 33 |
-
from
|
| 34 |
-
from
|
| 35 |
BaseModelOutputWithPast,
|
| 36 |
CausalLMOutputWithPast,
|
| 37 |
SequenceClassifierOutputWithPast,
|
| 38 |
TokenClassifierOutput,
|
| 39 |
)
|
| 40 |
-
from
|
| 41 |
-
from
|
| 42 |
-
from
|
| 43 |
add_start_docstrings,
|
| 44 |
add_start_docstrings_to_model_forward,
|
| 45 |
is_flash_attn_greater_or_equal_2_10,
|
|
@@ -47,7 +47,7 @@ from ...utils import (
|
|
| 47 |
replace_return_docstrings,
|
| 48 |
)
|
| 49 |
from .configuration_gemma import GemmaConfig
|
| 50 |
-
|
| 51 |
|
| 52 |
logger = logging.get_logger(__name__)
|
| 53 |
|
|
@@ -105,27 +105,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
| 105 |
|
| 106 |
return causal_mask
|
| 107 |
|
| 108 |
-
|
| 109 |
-
class GemmaRMSNorm(nn.Module):
|
| 110 |
-
def __init__(self, dim: int, eps: float = 1e-6):
|
| 111 |
-
super().__init__()
|
| 112 |
-
self.eps = eps
|
| 113 |
-
self.weight = nn.Parameter(torch.zeros(dim))
|
| 114 |
-
|
| 115 |
-
def _norm(self, x):
|
| 116 |
-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 117 |
-
|
| 118 |
-
def forward(self, x):
|
| 119 |
-
output = self._norm(x.float())
|
| 120 |
-
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
|
| 121 |
-
# See https://github.com/huggingface/transformers/pull/29402
|
| 122 |
-
output = output * (1.0 + self.weight.float())
|
| 123 |
-
return output.type_as(x)
|
| 124 |
-
|
| 125 |
-
def extra_repr(self):
|
| 126 |
-
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
| 127 |
-
|
| 128 |
-
|
| 129 |
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
|
| 130 |
|
| 131 |
|
|
@@ -528,7 +507,7 @@ class GemmaSdpaAttention(GemmaAttention):
|
|
| 528 |
|
| 529 |
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 530 |
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 531 |
-
if causal_mask is not None:
|
| 532 |
query_states = query_states.contiguous()
|
| 533 |
key_states = key_states.contiguous()
|
| 534 |
value_states = value_states.contiguous()
|
|
|
|
| 27 |
from torch import nn
|
| 28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
|
| 30 |
+
from transformers.activations import ACT2FN
|
| 31 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
| 32 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 33 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 34 |
+
from transformers.modeling_outputs import (
|
| 35 |
BaseModelOutputWithPast,
|
| 36 |
CausalLMOutputWithPast,
|
| 37 |
SequenceClassifierOutputWithPast,
|
| 38 |
TokenClassifierOutput,
|
| 39 |
)
|
| 40 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 41 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 42 |
+
from transformers.utils import (
|
| 43 |
add_start_docstrings,
|
| 44 |
add_start_docstrings_to_model_forward,
|
| 45 |
is_flash_attn_greater_or_equal_2_10,
|
|
|
|
| 47 |
replace_return_docstrings,
|
| 48 |
)
|
| 49 |
from .configuration_gemma import GemmaConfig
|
| 50 |
+
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
|
| 51 |
|
| 52 |
logger = logging.get_logger(__name__)
|
| 53 |
|
|
|
|
| 105 |
|
| 106 |
return causal_mask
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
|
| 109 |
|
| 110 |
|
|
|
|
| 507 |
|
| 508 |
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 509 |
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 510 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 511 |
query_states = query_states.contiguous()
|
| 512 |
key_states = key_states.contiguous()
|
| 513 |
value_states = value_states.contiguous()
|