ToastyPigeon committed on
Commit
437ff8f
·
verified ·
1 Parent(s): d81fb9b

Fix configuration file

Browse files
Files changed (1) hide show
  1. configuration_gemmagain.py +186 -39
configuration_gemmagain.py CHANGED
@@ -1,39 +1,186 @@
1
- {
2
- "architectures": [
3
- "Gemma3ForCausalLM"
4
- ],
5
- "auto_map": {
6
- "AutoConfig": "configuration_gemmagain.GemmagainConfig",
7
- "AutoModelForCausalLM": "modeling_gemmagain.GemmagainForCausalLM"
8
- },
9
- "attention_bias": false,
10
- "attention_dropout": 0.0,
11
- "attn_logit_softcapping": null,
12
- "cache_implementation": "hybrid",
13
- "final_logit_softcapping": null,
14
- "head_dim": 256,
15
- "hidden_activation": "gelu_pytorch_tanh",
16
- "hidden_size": 2560,
17
- "initializer_range": 0.02,
18
- "intermediate_size": 10240,
19
- "max_position_embeddings": 131072,
20
- "model_type": "gemma3",
21
- "num_attention_heads": 8,
22
- "num_hidden_layers": 34,
23
- "num_key_value_heads": 4,
24
- "query_pre_attn_scalar": 256,
25
- "rms_norm_eps": 1e-06,
26
- "rope_local_base_freq": 10000.0,
27
- "rope_scaling": {
28
- "factor": 8.0,
29
- "rope_type": "linear"
30
- },
31
- "rope_theta": 1000000.0,
32
- "sliding_window": 1024,
33
- "sliding_window_pattern": 6,
34
- "layer_sequence": [[0, 34, 1]],
35
- "torch_dtype": "bfloat16",
36
- "use_cache": true,
37
- "vocab_size": 262208,
38
- "transformers_version": "4.51.0"
39
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Gemmagain model configuration - Gemma3 with layer looping support"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class GemmagainConfig(PretrainedConfig):
    r"""
    Configuration class for Gemmagain - a Gemma3 text model with layer looping support.

    This extends Gemma3TextConfig to add the `layer_sequence` parameter which controls
    how layers are executed, allowing layers to be repeated multiple times.

    Args:
        vocab_size (`int`, *optional*, defaults to 262208):
            Vocabulary size of the model.
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 10240):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 34):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            Number of key_value heads for GQA.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The activation function.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            Maximum sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal initializer for weight matrices.
        layer_sequence (`list`, *optional*):
            Order to execute layers. Defaults to all layers once.
            Flexible format - each item can be:
            - An integer: single layer index (e.g., 5 means layer 5)
            - A 2-element list [start, end]: range of layers (e.g., [4, 20] means layers 4-19)
            - A 3-element list [start, end, repeats]: range repeated N times
            Examples:
            - [[0, 34, 1]]: all 34 layers once
            - [[0, 10], [10, 28, 2], [28, 34]]: layers 0-9, then 10-27 twice, then 28-33
        layer_types (`list`, *optional*):
            Attention pattern for each layer ("sliding_attention" or "full_attention").
        sliding_window (`int`, *optional*, defaults to 1024):
            Size of the sliding window for sliding attention layers.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            Base period for RoPE embeddings (global attention).
        rope_local_base_freq (`float`, *optional*, defaults to 10000.0):
            Base period for RoPE embeddings (local/sliding attention).
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            Scaling factor for attention scores.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon for RMS normalization.
        attention_bias (`bool`, *optional*, defaults to False):
            Whether to use bias in attention projections.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for attention.
        final_logit_softcapping (`float`, *optional*):
            Softcapping for final logits.
        attn_logit_softcapping (`float`, *optional*):
            Softcapping for attention logits.
        rope_scaling (`dict`, *optional*):
            RoPE scaling configuration.
        use_bidirectional_attention (`bool`, *optional*, defaults to False):
            If True, use bidirectional attention instead of causal.

    Raises:
        ValueError: If `layer_sequence` contains an entry that is not an int or a
            2/3-element list/tuple, or whose layer indices fall outside
            `[0, num_hidden_layers]`.
    """

    # Kept as "gemma3" so tokenizer/processor auto-resolution and weight layouts
    # match the upstream Gemma3 text model this config extends.
    model_type = "gemma3"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel sharding plan (column-wise for input projections,
    # row-wise for output projections), mirroring the Gemma3 plan.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: (input tensor names, output tensor names) per stage.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=262_208,
        hidden_size=2560,
        intermediate_size=10240,
        num_hidden_layers=34,
        num_attention_heads=8,
        num_key_value_heads=4,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=131_072,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=1_000_000.0,
        attention_bias=False,
        attention_dropout=0.0,
        query_pre_attn_scalar=256,
        sliding_window=1024,
        layer_types=None,
        layer_sequence=None,
        final_logit_softcapping=None,
        attn_logit_softcapping=None,
        rope_scaling=None,
        rope_local_base_freq=10_000.0,
        use_bidirectional_attention=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.use_bidirectional_attention = use_bidirectional_attention

        # Bidirectional attention looks both directions, so the effective window
        # is halved (plus one for the center position) to keep coverage comparable.
        if use_bidirectional_attention:
            self.sliding_window = (self.sliding_window // 2) + 1

        self.rope_local_base_freq = rope_local_base_freq
        self.rope_scaling = rope_scaling
        rope_config_validation(self)

        # Layer sequence for looping - defaults to all layers once.
        if layer_sequence is None:
            layer_sequence = [[0, num_hidden_layers, 1]]
        # Fail fast on malformed sequences instead of erroring deep inside the
        # modeling code at forward time.
        self._validate_layer_sequence(layer_sequence, num_hidden_layers)
        self.layer_sequence = layer_sequence

        # Layer types (sliding vs full attention). The pattern may arrive via
        # kwargs (e.g. from a serialized config's "sliding_window_pattern" key).
        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
        self.layer_types = layer_types
        if self.layer_types is None:
            # Every Nth layer (N = pattern) uses full attention; the rest slide.
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)

    @staticmethod
    def _validate_layer_sequence(layer_sequence, num_hidden_layers):
        """Check that every `layer_sequence` entry is well-formed and in range.

        Accepted entry shapes (matching the class docstring): a single int layer
        index, `[start, end]`, or `[start, end, repeats]` where `end` is
        exclusive. Raises ValueError otherwise.
        """
        for item in layer_sequence:
            if isinstance(item, int):
                if not 0 <= item < num_hidden_layers:
                    raise ValueError(
                        f"layer_sequence entry {item} is out of range for "
                        f"{num_hidden_layers} hidden layers"
                    )
            elif isinstance(item, (list, tuple)) and len(item) in (2, 3):
                start, end = item[0], item[1]
                if not (isinstance(start, int) and isinstance(end, int)):
                    raise ValueError(f"layer_sequence range bounds must be ints, got {item}")
                # `end` is exclusive, so end == num_hidden_layers is valid.
                if not 0 <= start < end <= num_hidden_layers:
                    raise ValueError(
                        f"layer_sequence range {item} is invalid for "
                        f"{num_hidden_layers} hidden layers"
                    )
            else:
                raise ValueError(
                    "layer_sequence entries must be an int, [start, end], or "
                    f"[start, end, repeats]; got {item!r}"
                )
184
# Public API of this module: only the config class is exported via `import *`.
__all__ = ["GemmagainConfig"]