ToastyPigeon committed
Commit 956ac81 · verified · 1 parent: b2e6d0f

Upload modeling_gemmagain.py with huggingface_hub
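The commit message says the file was pushed with huggingface_hub. A minimal sketch of that kind of upload, assuming a placeholder repo id (the exact call used for this commit is not shown here):

# Hypothetical reproduction of an upload like this commit; the repo id below is
# a placeholder, not the actual repository.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="modeling_gemmagain.py",   # local file to push
    path_in_repo="modeling_gemmagain.py",      # destination path inside the repo
    repo_id="your-username/your-model-repo",   # placeholder
    commit_message="Upload modeling_gemmagain.py with huggingface_hub",
)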

Files changed (1):
  1. modeling_gemmagain.py (+17 -17)
modeling_gemmagain.py CHANGED
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Gemmagain - Gemma3 text model with layer looping support (wrapper approach).
+Gemma3 - Gemma3 text model with layer looping support (wrapper approach).
 
 This model allows running the same physical layers multiple times in sequence,
 enabling parameter-efficient deep networks. Compatible with standard Gemma3 weights.
@@ -42,9 +42,9 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
 from transformers.utils.deprecation import deprecate_kwarg
 
 try:
-    from .configuration_gemmagain import GemmagainConfig
+    from .configuration_gemmagain import Gemma3Config
 except ImportError:
-    from configuration_gemmagain import GemmagainConfig
+    from configuration_gemmagain import Gemma3Config
 
 
 logger = logging.get_logger(__name__)
@@ -64,7 +64,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
 
 
 class Gemma3MLP(nn.Module):
-    def __init__(self, config: GemmagainConfig):
+    def __init__(self, config: Gemma3Config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -101,7 +101,7 @@ class Gemma3RMSNorm(nn.Module):
 class Gemma3RotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor
 
-    def __init__(self, config: GemmagainConfig, device=None):
+    def __init__(self, config: Gemma3Config, device=None):
         super().__init__()
         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -195,7 +195,7 @@ def eager_attention_forward(
 class Gemma3Attention(nn.Module):
     """Multi-headed attention with support for virtual layer index."""
 
-    def __init__(self, config: GemmagainConfig, layer_idx: int):
+    def __init__(self, config: Gemma3Config, layer_idx: int):
         super().__init__()
         self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
         self.config = config
@@ -276,7 +276,7 @@ class Gemma3Attention(nn.Module):
 
 
 class Gemma3DecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: GemmagainConfig, layer_idx: int):
+    def __init__(self, config: Gemma3Config, layer_idx: int):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -337,8 +337,8 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
 
 
 @auto_docstring
-class GemmagainPreTrainedModel(PreTrainedModel):
-    config_class = GemmagainConfig
+class Gemma3PreTrainedModel(PreTrainedModel):
+    config_class = Gemma3Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Gemma3DecoderLayer"]
@@ -388,8 +388,8 @@ def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, in
 
 
 @auto_docstring
-class GemmagainModel(GemmagainPreTrainedModel):
-    def __init__(self, config: GemmagainConfig):
+class Gemma3Model(Gemma3PreTrainedModel):
+    def __init__(self, config: Gemma3Config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -514,14 +514,14 @@ class GemmagainModel(GemmagainPreTrainedModel):
 
 
 @auto_docstring
-class GemmagainForCausalLM(GemmagainPreTrainedModel, GenerationMixin):
+class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
-    def __init__(self, config: GemmagainConfig):
+    def __init__(self, config: Gemma3Config):
         super().__init__(config)
-        self.model = GemmagainModel(config)
+        self.model = Gemma3Model(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
@@ -583,7 +583,7 @@ class GemmagainForCausalLM(GemmagainPreTrainedModel, GenerationMixin):
 
 
 __all__ = [
-    "GemmagainForCausalLM",
-    "GemmagainModel",
-    "GemmagainPreTrainedModel",
+    "Gemma3ForCausalLM",
+    "Gemma3Model",
+    "Gemma3PreTrainedModel",
 ]
 
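The docstring in the first hunk describes the looping idea: the same physical decoder layers can be traversed more than once per forward pass, and Gemma3Attention receives a "virtual layer index" so each pass can be told apart. A toy sketch of that pattern follows; the class and argument names (LoopedStack, loop_count, virtual_idx) are illustrative assumptions, not the wrapper implemented in modeling_gemmagain.py:

# Illustrative only: a toy version of the layer-looping idea from the module
# docstring, not the actual wrapper code in modeling_gemmagain.py.
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Stand-in for a decoder layer that is told which virtual layer it is."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, virtual_idx: int) -> torch.Tensor:
        # A real decoder layer could use virtual_idx, e.g. to address a per-pass
        # cache slot or choose rope settings; here it is only carried through.
        return hidden_states + self.proj(hidden_states)


class LoopedStack(nn.Module):
    """Runs the same physical layers loop_count times in sequence."""

    def __init__(self, hidden_size: int, num_physical: int, loop_count: int):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(hidden_size) for _ in range(num_physical))
        self.loop_count = loop_count

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for loop in range(self.loop_count):
            for idx, layer in enumerate(self.layers):
                # The virtual index is unique per application even though the
                # weights repeat: loop 0 -> 0..3, loop 1 -> 4..7, and so on.
                virtual_idx = loop * len(self.layers) + idx
                hidden_states = layer(hidden_states, virtual_idx)
        return hidden_states


# Four physical layers looped three times: twelve layer applications,
# but only four layers' worth of parameters.
stack = LoopedStack(hidden_size=16, num_physical=4, loop_count=3)
out = stack(torch.randn(2, 5, 16))

With four physical layers and loop_count=3 the stack performs twelve layer applications while storing parameters for only four, which is what the docstring means by parameter-efficient deep networks; since the physical layers are unchanged, standard Gemma3 weights can still be loaded into them.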