refactor changes
modeling_cerule_gemma.py (+7 -7)
@@ -872,7 +872,7 @@ if is_torch_fx_available():

logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "GemmaConfig"
+_CONFIG_FOR_DOC = "CeruleGemmaConfig"


def _get_unpad_data(attention_mask):

@@ -1003,7 +1003,7 @@ class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Ignore copy
-    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

@@ -1396,7 +1396,7 @@ GEMMA_ATTENTION_CLASSES = {

# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
class GemmaDecoderLayer(nn.Module):
-    def __init__(self, config: GemmaConfig, layer_idx: int):
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

@@ -1480,7 +1480,7 @@ GEMMA_START_DOCSTRING = r"""
    and behavior.

    Parameters:
-        config ([`GemmaConfig`]):
+        config ([`CeruleGemmaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.

@@ -1492,7 +1492,7 @@ GEMMA_START_DOCSTRING = r"""
    GEMMA_START_DOCSTRING,
)
class GemmaPreTrainedModel(PreTrainedModel):
-    config_class = GemmaConfig
+    config_class = CeruleGemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]

@@ -1618,7 +1618,7 @@ class GemmaModel(GemmaPreTrainedModel):
        config: GemmaConfig
    """

-    def __init__(self, config: GemmaConfig):
+    def __init__(self, config: CeruleGemmaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

@@ -2155,7 +2155,7 @@ from .configuration_gemma import CeruleGemmaConfig
class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
    config_class = CeruleGemmaConfig

-    def __init__(self, config: GemmaConfig):
+    def __init__(self, config: CeruleGemmaConfig):
        super(CeruleGemmaModel, self).__init__(config)
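For context, config_class is what PreTrainedModel.from_pretrained uses to load and validate the configuration, so pointing it at CeruleGemmaConfig lets the Cerule config type flow through every constructor touched above. A minimal sketch of exercising that wiring follows; the flat import paths, the hyper-parameter values, and the assumption that CeruleGemmaConfig subclasses the upstream GemmaConfig and carries no vision-tower fields are illustrative, not taken from this commit.

# Sketch only: in the repo these modules live in a package and use relative imports.
from configuration_gemma import CeruleGemmaConfig
from modeling_cerule_gemma import CeruleGemmaModel, GemmaPreTrainedModel

# After this commit the pretrained base class resolves configs through
# CeruleGemmaConfig instead of the upstream Gemma config (hunk at line 1495).
assert GemmaPreTrainedModel.config_class is CeruleGemmaConfig

# A toy configuration; values are placeholders, assuming CeruleGemmaConfig
# accepts the usual Gemma hyper-parameters.
config = CeruleGemmaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=1,
)

# Assumes the text-only path: with no vision-tower fields on the config,
# CeruleGemmaModel builds just the Gemma decoder stack.
model = CeruleGemmaModel(config)
assert isinstance(model.config, CeruleGemmaConfig)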