Upload modeling_gemmagain.py with huggingface_hub
Browse files
- modeling_gemmagain.py +17 -17
modeling_gemmagain.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
|
| 18 |
This model allows running the same physical layers multiple times in sequence,
|
| 19 |
enabling parameter-efficient deep networks. Compatible with standard Gemma3 weights.
|
|
@@ -42,9 +42,9 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
|
|
| 42 |
from transformers.utils.deprecation import deprecate_kwarg
|
| 43 |
|
| 44 |
try:
|
| 45 |
-
from .configuration_gemmagain import
|
| 46 |
except ImportError:
|
| 47 |
-
from configuration_gemmagain import
|
| 48 |
|
| 49 |
|
| 50 |
logger = logging.get_logger(__name__)
|
|
@@ -64,7 +64,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
|
|
| 64 |
|
| 65 |
|
| 66 |
class Gemma3MLP(nn.Module):
|
| 67 |
-
def __init__(self, config:
|
| 68 |
super().__init__()
|
| 69 |
self.config = config
|
| 70 |
self.hidden_size = config.hidden_size
|
|
@@ -101,7 +101,7 @@ class Gemma3RMSNorm(nn.Module):
|
|
| 101 |
class Gemma3RotaryEmbedding(nn.Module):
|
| 102 |
inv_freq: torch.Tensor
|
| 103 |
|
| 104 |
-
def __init__(self, config:
|
| 105 |
super().__init__()
|
| 106 |
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 107 |
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
|
@@ -195,7 +195,7 @@ def eager_attention_forward(
|
|
| 195 |
class Gemma3Attention(nn.Module):
|
| 196 |
"""Multi-headed attention with support for virtual layer index."""
|
| 197 |
|
| 198 |
-
def __init__(self, config:
|
| 199 |
super().__init__()
|
| 200 |
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
| 201 |
self.config = config
|
|
@@ -276,7 +276,7 @@ class Gemma3Attention(nn.Module):
|
|
| 276 |
|
| 277 |
|
| 278 |
class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
| 279 |
-
def __init__(self, config:
|
| 280 |
super().__init__()
|
| 281 |
self.config = config
|
| 282 |
self.hidden_size = config.hidden_size
|
|
@@ -337,8 +337,8 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
|
| 337 |
|
| 338 |
|
| 339 |
@auto_docstring
|
| 340 |
-
class GemmagainPreTrainedModel(PreTrainedModel):
|
| 341 |
-
config_class =
|
| 342 |
base_model_prefix = "model"
|
| 343 |
supports_gradient_checkpointing = True
|
| 344 |
_no_split_modules = ["Gemma3DecoderLayer"]
|
|
@@ -388,8 +388,8 @@ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, in
|
|
| 388 |
|
| 389 |
|
| 390 |
@auto_docstring
|
| 391 |
-
class GemmagainModel(GemmagainPreTrainedModel):
|
| 392 |
-
def __init__(self, config:
|
| 393 |
super().__init__(config)
|
| 394 |
self.padding_idx = config.pad_token_id
|
| 395 |
self.vocab_size = config.vocab_size
|
|
@@ -514,14 +514,14 @@ class GemmagainModel(GemmagainPreTrainedModel):
|
|
| 514 |
|
| 515 |
|
| 516 |
@auto_docstring
|
| 517 |
-
class GemmagainForCausalLM(GemmagainPreTrainedModel, GenerationMixin):
|
| 518 |
_tied_weights_keys = ["lm_head.weight"]
|
| 519 |
_tp_plan = {"lm_head": "colwise_rep"}
|
| 520 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 521 |
|
| 522 |
-
def __init__(self, config:
|
| 523 |
super().__init__(config)
|
| 524 |
-
self.model = GemmagainModel(config)
|
| 525 |
self.vocab_size = config.vocab_size
|
| 526 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 527 |
|
|
@@ -583,7 +583,7 @@ class GemmagainForCausalLM(GemmagainPreTrainedModel, GenerationMixin):
|
|
| 583 |
|
| 584 |
|
| 585 |
__all__ = [
|
| 586 |
-
"GemmagainForCausalLM",
|
| 587 |
-
"GemmagainModel",
|
| 588 |
-
"GemmagainPreTrainedModel",
|
| 589 |
]
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
"""
|
| 16 |
+
Gemma3 - Gemma3 text model with layer looping support (wrapper approach).
|
| 17 |
|
| 18 |
This model allows running the same physical layers multiple times in sequence,
|
| 19 |
enabling parameter-efficient deep networks. Compatible with standard Gemma3 weights.
|
|
|
|
| 42 |
from transformers.utils.deprecation import deprecate_kwarg
|
| 43 |
|
| 44 |
try:
|
| 45 |
+
from .configuration_gemmagain import Gemma3Config
|
| 46 |
except ImportError:
|
| 47 |
+
from configuration_gemmagain import Gemma3Config
|
| 48 |
|
| 49 |
|
| 50 |
logger = logging.get_logger(__name__)
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
class Gemma3MLP(nn.Module):
|
| 67 |
+
def __init__(self, config: Gemma3Config):
|
| 68 |
super().__init__()
|
| 69 |
self.config = config
|
| 70 |
self.hidden_size = config.hidden_size
|
|
|
|
| 101 |
class Gemma3RotaryEmbedding(nn.Module):
|
| 102 |
inv_freq: torch.Tensor
|
| 103 |
|
| 104 |
+
def __init__(self, config: Gemma3Config, device=None):
|
| 105 |
super().__init__()
|
| 106 |
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 107 |
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
|
|
|
| 195 |
class Gemma3Attention(nn.Module):
|
| 196 |
"""Multi-headed attention with support for virtual layer index."""
|
| 197 |
|
| 198 |
+
def __init__(self, config: Gemma3Config, layer_idx: int):
|
| 199 |
super().__init__()
|
| 200 |
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
| 201 |
self.config = config
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
| 279 |
+
def __init__(self, config: Gemma3Config, layer_idx: int):
|
| 280 |
super().__init__()
|
| 281 |
self.config = config
|
| 282 |
self.hidden_size = config.hidden_size
|
|
|
|
| 337 |
|
| 338 |
|
| 339 |
@auto_docstring
|
| 340 |
+
class Gemma3PreTrainedModel(PreTrainedModel):
|
| 341 |
+
config_class = Gemma3Config
|
| 342 |
base_model_prefix = "model"
|
| 343 |
supports_gradient_checkpointing = True
|
| 344 |
_no_split_modules = ["Gemma3DecoderLayer"]
|
|
|
|
| 388 |
|
| 389 |
|
| 390 |
@auto_docstring
|
| 391 |
+
class Gemma3Model(Gemma3PreTrainedModel):
|
| 392 |
+
def __init__(self, config: Gemma3Config):
|
| 393 |
super().__init__(config)
|
| 394 |
self.padding_idx = config.pad_token_id
|
| 395 |
self.vocab_size = config.vocab_size
|
|
|
|
| 514 |
|
| 515 |
|
| 516 |
@auto_docstring
|
| 517 |
+
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
| 518 |
_tied_weights_keys = ["lm_head.weight"]
|
| 519 |
_tp_plan = {"lm_head": "colwise_rep"}
|
| 520 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 521 |
|
| 522 |
+
def __init__(self, config: Gemma3Config):
|
| 523 |
super().__init__(config)
|
| 524 |
+
self.model = Gemma3Model(config)
|
| 525 |
self.vocab_size = config.vocab_size
|
| 526 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 527 |
|
|
|
|
| 583 |
|
| 584 |
|
| 585 |
__all__ = [
|
| 586 |
+
"Gemma3ForCausalLM",
|
| 587 |
+
"Gemma3Model",
|
| 588 |
+
"Gemma3PreTrainedModel",
|
| 589 |
]
|