coolpoodle committed on
Commit 2d58dce · verified · 1 Parent(s): 89db275

Initial upload

Files changed (4)
  1. README.md +117 -0
  2. gate_projections.pt +3 -0
  3. loop_config.json +11 -0
  4. modeling_qwen_loop.py +380 -0
README.md ADDED
@@ -0,0 +1,117 @@
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ base_model:
6
+ - Qwen/Qwen3-0.6B
7
+ ---
8
+ # Qwen3-0.6B with Looped (Poodle) Attention
9
+ Hello world! I’m poodle. I wanted to share an open-source methodology for how I implemented Loop Attention in Qwen3-0.6B. Rather than just hand you the weights, I’ve also included the training code written for Qwen’s architecture.
10
+
11
+ I hope you enjoy!
12
+
13
+ This model implements **Loop Attention**, a modification that runs two forward passes through each attention layer, on top of Qwen3-0.6B:
14
+
15
+ 1. **Loop 1**: Captures global context using standard attention
16
+ 2. **Loop 2**: Mixes global context with local attention via learned gates
17
+
18
+ ## Results
19
+
20
+ | Model | Loss | Perplexity |
21
+ |-------|------|------------|
22
+ | Baseline Qwen3-0.6B | 3.7431 | 42.23 |
23
+ | **Loop Attention** | **3.5549** | **35.01** |
24
+ | Improvement | -0.1882 | -7.22 |
25
+
26
+ Loop Attention reduces perplexity by roughly **17%** (relative) on the WikiText-2 validation set.
27
+
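The perplexity column is just the exponential of the reported loss, so the table is easy to sanity-check:

```python
import math

baseline_loss, loop_loss = 3.7431, 3.5549
print(math.exp(baseline_loss))   # ~42.23 (baseline perplexity)
print(math.exp(loop_loss))       # ~35.01 (Loop Attention perplexity)
print(1 - math.exp(loop_loss) / math.exp(baseline_loss))  # ~0.17 -> ~17% relative improvement
```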
28
+ ## Architecture
29
+
30
+ Loop Attention adds a lightweight gating mechanism to each attention layer:
31
+ - **Gate Projection**: a learned per-head projection that maps each token’s query states to a scalar gate value in (0, 1)
32
+ - **Trainable Parameters**: only 57,792 (28 layers × 16 heads × (128 weights + 1 bias) per head gate)
33
+ - **Base Model**: Frozen Qwen3-0.6B weights
34
+
35
+ For each token, the gate decides how much of the global context from Loop 1 versus the windowed local attention from Loop 2 is used.
36
+
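As a minimal sketch of the mechanism (mirroring `LoopGate` in `modeling_qwen_loop.py`), the gate is a per-head dot product between the query and a learned vector, squashed by a sigmoid, and the two attention outputs are blended per token:

```python
import torch

def loop_gate(query_states, weight, bias):
    # query_states: [batch, heads, seq, head_dim]; weight: [heads, head_dim]; bias: [heads]
    gate_logits = torch.einsum('bhsd,hd->bhs', query_states, weight) + bias.view(1, -1, 1)
    return torch.sigmoid(gate_logits).unsqueeze(-1)  # g: [batch, heads, seq, 1]

# Loop 2 then mixes the two attention outputs:
# attn_output = g * attn_global + (1 - g) * attn_local
```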
37
+ ## Usage
38
+
39
+ ```python
40
+ import torch
41
+ from modeling_qwen_loop import Qwen3LoopForCausalLM
42
+ from transformers import AutoTokenizer
43
+
44
+ # Load model
45
+ model = Qwen3LoopForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+ model = model.to(device)
48
+
49
+ # Load trained gates
50
+ gate_state = torch.load("gate_projections.pt", map_location=device)
51
+ for key, value in gate_state.items():
52
+ parts = key.split('.')
53
+ layer_idx = int(parts[1])
54
+ param_name = parts[-1]
55
+ if param_name == 'weight':
56
+ model.model.layers[layer_idx].self_attn.gate.weight.data = value.to(device)
57
+ elif param_name == 'bias':
58
+ model.model.layers[layer_idx].self_attn.gate.bias.data = value.to(device)
59
+
60
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
61
+ tokenizer.pad_token = tokenizer.eos_token
62
+
63
+ # Generate with Loop Attention (use_cache=False activates loops)
64
+ prompt = "The capital of France is"
65
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
66
+
67
+ model.eval()
68
+ with torch.no_grad():
69
+ output = model.generate(
70
+ input_ids=inputs.input_ids,
71
+ max_new_tokens=50,
72
+ do_sample=True,
73
+ temperature=0.7,
74
+ use_cache=False, # CRITICAL: Enables Loop Attention
75
+ pad_token_id=tokenizer.eos_token_id
76
+ )
77
+
78
+ print(tokenizer.decode(output[0], skip_special_tokens=True))
79
+ ```
80
+
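If you want to score text rather than generate, the wrapper's forward pass computes the cross-entropy loss when `labels` are passed. A small sketch, reusing `model`, `tokenizer`, and `device` from the snippet above (the sentence is an arbitrary example, not from the training data):

```python
text = "The quick brown fox jumps over the lazy dog."  # arbitrary example text
enc = tokenizer(text, return_tensors="pt").to(device)

model.eval()
with torch.no_grad():
    # use_cache=False routes through the two-pass Loop Attention path
    out = model(input_ids=enc.input_ids, labels=enc.input_ids, use_cache=False)

print(f"loss = {out.loss.item():.4f}, perplexity = {torch.exp(out.loss).item():.2f}")
```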
81
+ ## Important Notes
82
+
83
+ - `use_cache=False` is required during generation to activate Loop Attention
84
+ - With `use_cache=True` (default), the model behaves like standard Qwen3-0.6B
85
+ - The base Qwen3-0.6B weights are not modified; only gate projections are trained
86
+
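To see the fallback in action, the same forward pass can be run in both modes (reusing `model` and `inputs` from the usage snippet); with `use_cache=True` the loop layers stay in standard mode:

```python
with torch.no_grad():
    standard_out = model(input_ids=inputs.input_ids, use_cache=True)   # plain Qwen3 attention path
    looped_out   = model(input_ids=inputs.input_ids, use_cache=False)  # Loop 1 + Loop 2 path

# The outputs only diverge where the learned gates move away from 1.0
print((standard_out.logits - looped_out.logits).abs().max())
```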
87
+ ## Training Details
88
+
89
+ - **Dataset**: WikiText-2
90
+ - **Epochs**: 3
91
+ - **Batch Size**: 64 (16 x 4 gradient accumulation)
92
+ - **Learning Rate**: 3e-4 with warmup and linear decay
93
+ - **Max Length**: 512 tokens
94
+ - **Training Time**: ~39 minutes on A100 80GB
95
+
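For reference, a rough sketch of the gate-only training loop under these settings, using the helpers exposed by `Qwen3LoopForCausalLM` (`enable_gate_training_only`, `get_gate_parameters`). The `train_loader`, `total_steps`, and warmup length are placeholders, not the exact script used here:

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model.enable_gate_training_only()                        # freeze base weights, train only the gates
optimizer = AdamW(model.get_gate_parameters(), lr=3e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

accum_steps = 4                                          # micro-batch 16 x 4 accumulation = effective batch 64
model.train()
for step, batch in enumerate(train_loader):              # 512-token WikiText-2 chunks (placeholder loader)
    input_ids = batch["input_ids"].to(device)
    out = model(input_ids=input_ids, labels=input_ids, use_cache=False)  # loops active during training
    (out.loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```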
96
+ ## Files
97
+
98
+ - `modeling_qwen_loop.py` - Loop Attention implementation
99
+ - `gate_projections.pt` - Trained gate weights (249KB)
100
+ - `loop_config.json` - Training configuration
101
+
102
+ ## Citation
103
+
104
+ If you use this model, please cite:
105
+
106
+ ```
107
+ @misc{qwen3-loop-attention,
108
+ title={Loop Attention for Qwen3-0.6B},
109
+ year={2025},
110
+ publisher={HuggingFace},
111
+ url={https://huggingface.co/coolpoodle/qwen3-0.6b-looped}
112
+ }
113
+ ```
114
+
115
+ ## License
116
+
117
+ This model inherits the license from [Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B).
gate_projections.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:007fc563c307fe1a06bfe14a1d20e3b5dc76f9e337a845cc28a5470a3223f596
3
+ size 249257
loop_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "base_model": "/content/Qwen3-0.6B",
3
+ "loop_window_size": 64,
4
+ "num_layers": 28,
5
+ "num_heads": 16,
6
+ "head_dim": 128,
7
+ "final_val_loss": 3.6202090362707775,
8
+ "final_val_ppl": 37.34537124633789,
9
+ "training_epochs": 3,
10
+ "training_time_minutes": 38.990576179822284
11
+ }
modeling_qwen_loop.py ADDED
@@ -0,0 +1,380 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import AutoModelForCausalLM, AutoConfig
5
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, apply_rotary_pos_emb, repeat_kv
6
+
7
+
8
+ class Qwen3LoopConfig:
9
+ def __init__(self, base_config, loop_window_size=64):
10
+ self.base_config = base_config
11
+ self.loop_window_size = loop_window_size
12
+
13
+ def __getattr__(self, name):
14
+ return getattr(self.base_config, name)
15
+
16
+
17
+ class LoopGate(nn.Module):
18
+ def __init__(self, num_heads, head_dim):
19
+ super().__init__()
20
+ # Initialize weights to near-zero random noise to break symmetry
21
+ self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
22
+
23
+ # Initialize bias to +5.0; this matters if you port this to another architecture, so don't forget it.
24
+ # Sigmoid(5.0) ≈ 0.993
25
+ self.bias = nn.Parameter(torch.full((num_heads,), 5.0))
26
+
27
+ def forward(self, query_states):
28
+ # [batch, heads, seq, dim] -> [batch, heads, seq, 1]
29
+ gate_logits = torch.einsum('bhsd,hd->bhs', query_states, self.weight) + self.bias.view(1, -1, 1)
30
+ return torch.sigmoid(gate_logits).unsqueeze(-1)
31
+
32
+
33
+
34
+ # Loop Attention
35
+ class Qwen3LoopAttention(nn.Module):
36
+ def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
37
+ super().__init__()
38
+ self.loop_window_size = loop_window_size
39
+ self.layer_idx = original_attn.layer_idx
40
+
41
+ # Get config values
42
+ config = original_attn.config
43
+ self.hidden_size = config.hidden_size
44
+ self.num_heads = config.num_attention_heads
45
+ self.head_dim = original_attn.head_dim
46
+ self.num_key_value_heads = config.num_key_value_heads
47
+ self.num_key_value_groups = original_attn.num_key_value_groups
48
+ self.scaling = original_attn.scaling
49
+ self.is_causal = original_attn.is_causal
50
+ # Qwen3 uses head_dim * num_heads which may differ from hidden_size
51
+ self.attn_hidden_size = self.num_heads * self.head_dim
52
+
53
+ # Share weights by reference (No extra memory)
54
+ self.q_proj = original_attn.q_proj
55
+ self.k_proj = original_attn.k_proj
56
+ self.v_proj = original_attn.v_proj
57
+ self.o_proj = original_attn.o_proj
58
+
59
+ # Qwen3 specific: q_norm and k_norm
60
+ self.q_norm = original_attn.q_norm
61
+ self.k_norm = original_attn.k_norm
62
+
63
+ # New Gate
64
+ self.gate = LoopGate(self.num_heads, self.head_dim)
65
+
66
+ # Loop State
67
+ self._loop_mode = 0
68
+ self._global_k = None
69
+ self._global_v = None
70
+
71
+ def forward(self, hidden_states, position_embeddings,
72
+ attention_mask=None, past_key_values=None,
73
+ cache_position=None, **kwargs):
74
+ bsz, q_len, _ = hidden_states.size()
75
+
76
+ query_states = self.q_proj(hidden_states)
77
+ key_states = self.k_proj(hidden_states)
78
+ value_states = self.v_proj(hidden_states)
79
+
80
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
81
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
82
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
83
+
84
+ # Qwen3: Apply Q/K normalization
85
+ query_states = self.q_norm(query_states)
86
+ key_states = self.k_norm(key_states)
87
+
88
+ # RoPE - Qwen3 passes position_embeddings from model level
89
+ cos, sin = position_embeddings
90
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
91
+
92
+ # Update KV Cache
93
+ if past_key_values is not None:
94
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
95
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
96
+
97
+ key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
98
+ value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
99
+
100
+
101
+ if self._loop_mode == 1:
102
+ # Loop 1: Capture Global Context
103
+ self._global_k = key_states_rpt.detach()
104
+ self._global_v = value_states_rpt.detach()
105
+
106
+ attn_output = F.scaled_dot_product_attention(
107
+ query_states, key_states_rpt, value_states_rpt,
108
+ attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
109
+ )
110
+
111
+ elif self._loop_mode == 2:
112
+ # Loop 2: Mixed Attention
113
+ g = self.gate(query_states)
114
+
115
+ # Global (from cache)
116
+ attn_global = F.scaled_dot_product_attention(
117
+ query_states, self._global_k, self._global_v,
118
+ attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
119
+ )
120
+
121
+ # Local (Windowed)
122
+ ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
123
+ ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
124
+ mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))
125
+
126
+ # Create local attention mask
127
+ local_mask = torch.full(
128
+ (1, 1, q_len, key_states.shape[2]),
129
+ torch.finfo(query_states.dtype).min,
130
+ device=query_states.device,
131
+ dtype=query_states.dtype
132
+ )
133
+ local_mask.masked_fill_(mask_window, 0.0)
134
+
135
+ attn_local = F.scaled_dot_product_attention(
136
+ query_states, key_states_rpt, value_states_rpt,
137
+ attn_mask=local_mask, is_causal=False
138
+ )
139
+
140
+ # Mixing: If Bias=5.0, g ~ 1.0, so result is mostly global
141
+ attn_output = g * attn_global + (1.0 - g) * attn_local
142
+
143
+ else:
144
+ # Standard (for Inference/Generation fallback)
145
+ attn_output = F.scaled_dot_product_attention(
146
+ query_states, key_states_rpt, value_states_rpt,
147
+ attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
148
+ )
149
+
150
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.attn_hidden_size)
151
+ attn_output = self.o_proj(attn_output)
152
+
153
+ # Qwen3 expects (attn_output, attn_weights)
154
+ return attn_output, None
155
+
156
+
157
+ class Qwen3LoopForCausalLM(nn.Module):
158
+ """Wrapper that adds Loop Attention to Qwen3."""
159
+
160
+ def __init__(self, base_model, loop_window_size=64):
161
+ super().__init__()
162
+ self.model = base_model.model
163
+ self.lm_head = base_model.lm_head
164
+ self.config = base_model.config
165
+ self.loop_window_size = loop_window_size
166
+ self.generation_config = base_model.generation_config
167
+
168
+ # Replace attention layers with loop versions
169
+ for layer in self.model.layers:
170
+ if not isinstance(layer.self_attn, Qwen3LoopAttention):
171
+ new_attn = Qwen3LoopAttention(layer.self_attn, loop_window_size)
172
+ new_attn.to(layer.self_attn.q_proj.weight.device)
173
+ new_attn.to(layer.self_attn.q_proj.weight.dtype)
174
+ layer.self_attn = new_attn
175
+
176
+ @classmethod
177
+ def from_pretrained(cls, model_path, loop_window_size=64, **kwargs):
178
+ base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
179
+ return cls(base, loop_window_size)
180
+
181
+ def forward(self, input_ids=None, attention_mask=None, position_ids=None,
182
+ past_key_values=None, inputs_embeds=None, labels=None,
183
+ use_cache=None, output_attentions=None, output_hidden_states=None,
184
+ return_dict=None, cache_position=None, **kwargs):
185
+
186
+ if use_cache or (use_cache is None and self.config.use_cache and not self.training):
187
+ for layer in self.model.layers:
188
+ layer.self_attn._loop_mode = 0
189
+ return self._forward_standard(
190
+ input_ids=input_ids,
191
+ attention_mask=attention_mask,
192
+ position_ids=position_ids,
193
+ past_key_values=past_key_values,
194
+ inputs_embeds=inputs_embeds,
195
+ labels=labels,
196
+ use_cache=use_cache,
197
+ output_attentions=output_attentions,
198
+ output_hidden_states=output_hidden_states,
199
+ return_dict=return_dict,
200
+ cache_position=cache_position,
201
+ **kwargs
202
+ )
203
+
204
+ for layer in self.model.layers:
205
+ layer.self_attn._loop_mode = 1
206
+ with torch.no_grad():
207
+ self._forward_standard(
208
+ input_ids=input_ids,
209
+ attention_mask=attention_mask,
210
+ position_ids=position_ids,
211
+ past_key_values=None,
212
+ inputs_embeds=inputs_embeds,
213
+ use_cache=False,
214
+ **kwargs
215
+ )
216
+
217
+ for layer in self.model.layers:
218
+ layer.self_attn._loop_mode = 2
219
+ outputs = self._forward_standard(
220
+ input_ids=input_ids,
221
+ attention_mask=attention_mask,
222
+ position_ids=position_ids,
223
+ past_key_values=None,
224
+ inputs_embeds=inputs_embeds,
225
+ labels=labels,
226
+ use_cache=False,
227
+ output_attentions=output_attentions,
228
+ output_hidden_states=output_hidden_states,
229
+ return_dict=return_dict,
230
+ **kwargs
231
+ )
232
+
233
+ for layer in self.model.layers:
234
+ layer.self_attn._loop_mode = 0
235
+ layer.self_attn._global_k = None
236
+ layer.self_attn._global_v = None
237
+
238
+ return outputs
239
+
240
+ def _forward_standard(self, input_ids=None, attention_mask=None, position_ids=None,
241
+ past_key_values=None, inputs_embeds=None, labels=None,
242
+ use_cache=None, output_attentions=None, output_hidden_states=None,
243
+ return_dict=None, cache_position=None, **kwargs):
244
+ """Standard forward pass through the model."""
245
+ from transformers.modeling_outputs import CausalLMOutputWithPast
246
+
247
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
248
+
249
+ # Get hidden states from model
250
+ outputs = self.model(
251
+ input_ids=input_ids,
252
+ attention_mask=attention_mask,
253
+ position_ids=position_ids,
254
+ past_key_values=past_key_values,
255
+ inputs_embeds=inputs_embeds,
256
+ use_cache=use_cache,
257
+ output_attentions=output_attentions,
258
+ output_hidden_states=output_hidden_states,
259
+ return_dict=True,
260
+ cache_position=cache_position,
261
+ )
262
+
263
+ hidden_states = outputs.last_hidden_state
264
+ logits = self.lm_head(hidden_states)
265
+
266
+ loss = None
267
+ if labels is not None:
268
+ shift_logits = logits[..., :-1, :].contiguous()
269
+ shift_labels = labels[..., 1:].contiguous()
270
+ loss = F.cross_entropy(
271
+ shift_logits.view(-1, shift_logits.size(-1)),
272
+ shift_labels.view(-1),
273
+ ignore_index=-100
274
+ )
275
+
276
+ if not return_dict:
277
+ output = (logits,) + outputs[1:]
278
+ return (loss,) + output if loss is not None else output
279
+
280
+ return CausalLMOutputWithPast(
281
+ loss=loss,
282
+ logits=logits,
283
+ past_key_values=outputs.past_key_values,
284
+ hidden_states=outputs.hidden_states,
285
+ attentions=outputs.attentions,
286
+ )
287
+
288
+ def generate(self, input_ids=None, **kwargs):
289
+ """Generate text - always uses standard attention."""
290
+ for layer in self.model.layers:
291
+ layer.self_attn._loop_mode = 0
292
+ layer.self_attn._global_k = None
293
+ layer.self_attn._global_v = None
294
+
295
298
+
299
+ # Create a simple generation loop
300
+ device = input_ids.device
301
+ max_new_tokens = kwargs.get('max_new_tokens', 50)
302
+ temperature = kwargs.get('temperature', 1.0)
303
+ do_sample = kwargs.get('do_sample', False)
304
+ top_p = kwargs.get('top_p', 1.0)
305
+ pad_token_id = kwargs.get('pad_token_id', self.config.eos_token_id)
306
+ eos_token_id = kwargs.get('eos_token_id', self.config.eos_token_id)
307
+
308
+ generated = input_ids.clone()
309
+
310
+ for _ in range(max_new_tokens):
311
+ with torch.no_grad():
312
+ outputs = self(input_ids=generated, use_cache=kwargs.get('use_cache', True))
313
+ next_token_logits = outputs.logits[:, -1, :]
314
+
315
+ if do_sample and temperature > 0:
316
+ next_token_logits = next_token_logits / temperature
317
+ if top_p < 1.0:
318
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
319
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
320
+ sorted_indices_to_remove = cumulative_probs > top_p
321
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
322
+ sorted_indices_to_remove[..., 0] = False
323
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
324
+ next_token_logits[indices_to_remove] = float('-inf')
325
+
326
+ probs = F.softmax(next_token_logits, dim=-1)
327
+ next_token = torch.multinomial(probs, num_samples=1)
328
+ else:
329
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
330
+
331
+ generated = torch.cat([generated, next_token], dim=-1)
332
+
333
+ if eos_token_id is not None and (next_token == eos_token_id).all():
334
+ break
335
+
336
+ return generated
337
+
338
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
339
+ attention_mask=None, inputs_embeds=None,
340
+ cache_position=None, **kwargs):
341
+ """Prepare inputs for generation step."""
342
+ if past_key_values is not None:
343
+ if inputs_embeds is not None:
344
+ input_ids = input_ids[:, -cache_position.shape[0]:]
345
+ elif input_ids.shape[1] != cache_position.shape[0]:
346
+ input_ids = input_ids[:, cache_position]
347
+
348
+ position_ids = kwargs.get("position_ids", None)
349
+ if attention_mask is not None and position_ids is None:
350
+ position_ids = attention_mask.long().cumsum(-1) - 1
351
+ position_ids.masked_fill_(attention_mask == 0, 1)
352
+ if past_key_values:
353
+ position_ids = position_ids[:, -input_ids.shape[1]:]
354
+
355
+ model_inputs = {
356
+ "input_ids": input_ids,
357
+ "position_ids": position_ids,
358
+ "cache_position": cache_position,
359
+ "past_key_values": past_key_values,
360
+ "use_cache": kwargs.get("use_cache", True),
361
+ "attention_mask": attention_mask,
362
+ }
363
+ return model_inputs
364
+
365
+ def enable_gate_training_only(self):
366
+ """Freeze all parameters except gates."""
367
+ self.requires_grad_(False)
368
+ for layer in self.model.layers:
369
+ layer.self_attn.gate.requires_grad_(True)
370
+
371
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
372
+ total = sum(p.numel() for p in self.parameters())
373
+ print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
374
+
375
+ def get_gate_parameters(self):
376
+ """Return list of gate parameters for optimizer."""
377
+ params = []
378
+ for layer in self.model.layers:
379
+ params.extend(layer.self_attn.gate.parameters())
380
+ return params