coolpoodle committed on
Commit
29fdd49
·
verified ·
1 Parent(s): b871d11

Update modeling_qwen_loop.py

Browse files

More comments yay!

Also can now save as .bin & .pt

Files changed (1) hide show
  1. modeling_qwen_loop.py +52 -8
modeling_qwen_loop.py CHANGED
@@ -6,6 +6,7 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, apply_rotar
6
 
7
 
8
  class Qwen3LoopConfig:
 
9
  def __init__(self, base_config, loop_window_size=64):
10
  self.base_config = base_config
11
  self.loop_window_size = loop_window_size
@@ -13,6 +14,7 @@ class Qwen3LoopConfig:
13
  def __getattr__(self, name):
14
  return getattr(self.base_config, name)
15
 
 
16
 
17
  class LoopGate(nn.Module):
18
  def __init__(self, num_heads, head_dim):
@@ -20,8 +22,10 @@ class LoopGate(nn.Module):
20
  # Initialize weights to near-zero random noise to break symmetry
21
  self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
22
 
23
- # Initialize bias to +5.0; this is important for anyone trying to implement this cross-architecture — don't forget it.
24
  # Sigmoid(5.0) ≈ 0.993
 
 
25
  self.bias = nn.Parameter(torch.full((num_heads,), 5.0))
26
 
27
  def forward(self, query_states):
@@ -31,7 +35,8 @@ class LoopGate(nn.Module):
31
 
32
 
33
 
34
- # Loop Attention
 
35
  class Qwen3LoopAttention(nn.Module):
36
  def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
37
  super().__init__()
@@ -73,6 +78,7 @@ class Qwen3LoopAttention(nn.Module):
73
  cache_position=None, **kwargs):
74
  bsz, q_len, _ = hidden_states.size()
75
 
 
76
  query_states = self.q_proj(hidden_states)
77
  key_states = self.k_proj(hidden_states)
78
  value_states = self.v_proj(hidden_states)
@@ -97,7 +103,6 @@ class Qwen3LoopAttention(nn.Module):
97
  key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
98
  value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
99
 
100
-
101
  if self._loop_mode == 1:
102
  # Loop 1: Capture Global Context
103
  self._global_k = key_states_rpt.detach()
@@ -112,13 +117,12 @@ class Qwen3LoopAttention(nn.Module):
112
  # Loop 2: Mixed Attention
113
  g = self.gate(query_states)
114
 
115
- # Global (from cache)
116
  attn_global = F.scaled_dot_product_attention(
117
  query_states, self._global_k, self._global_v,
118
  attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
119
  )
120
 
121
- # Local (Windowed)
122
  ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
123
  ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
124
  mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))
@@ -137,7 +141,7 @@ class Qwen3LoopAttention(nn.Module):
137
  attn_mask=local_mask, is_causal=False
138
  )
139
 
140
- # Mixing: If Bias=5.0, g ~ 1.0, so result is mostly global
141
  attn_output = g * attn_global + (1.0 - g) * attn_local
142
 
143
  else:
@@ -183,7 +187,9 @@ class Qwen3LoopForCausalLM(nn.Module):
183
  use_cache=None, output_attentions=None, output_hidden_states=None,
184
  return_dict=None, cache_position=None, **kwargs):
185
 
 
186
  if use_cache or (use_cache is None and self.config.use_cache and not self.training):
 
187
  for layer in self.model.layers:
188
  layer.self_attn._loop_mode = 0
189
  return self._forward_standard(
@@ -201,6 +207,7 @@ class Qwen3LoopForCausalLM(nn.Module):
201
  **kwargs
202
  )
203
 
 
204
  for layer in self.model.layers:
205
  layer.self_attn._loop_mode = 1
206
  with torch.no_grad():
@@ -214,6 +221,7 @@ class Qwen3LoopForCausalLM(nn.Module):
214
  **kwargs
215
  )
216
 
 
217
  for layer in self.model.layers:
218
  layer.self_attn._loop_mode = 2
219
  outputs = self._forward_standard(
@@ -230,6 +238,7 @@ class Qwen3LoopForCausalLM(nn.Module):
230
  **kwargs
231
  )
232
 
 
233
  for layer in self.model.layers:
234
  layer.self_attn._loop_mode = 0
235
  layer.self_attn._global_k = None
@@ -287,6 +296,7 @@ class Qwen3LoopForCausalLM(nn.Module):
287
 
288
  def generate(self, input_ids=None, **kwargs):
289
  """Generate text - always uses standard attention."""
 
290
  for layer in self.model.layers:
291
  layer.self_attn._loop_mode = 0
292
  layer.self_attn._global_k = None
@@ -338,7 +348,8 @@ class Qwen3LoopForCausalLM(nn.Module):
338
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
339
  attention_mask=None, inputs_embeds=None,
340
  cache_position=None, **kwargs):
341
- """Prepare inputs for generation step."""
 
342
  if past_key_values is not None:
343
  if inputs_embeds is not None:
344
  input_ids = input_ids[:, -cache_position.shape[0]:]
@@ -372,9 +383,42 @@ class Qwen3LoopForCausalLM(nn.Module):
372
  total = sum(p.numel() for p in self.parameters())
373
  print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  def get_gate_parameters(self):
376
- """Return list of gate parameters for optimizer."""
377
  params = []
378
  for layer in self.model.layers:
379
  params.extend(layer.self_attn.gate.parameters())
380
  return params
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  class Qwen3LoopConfig:
9
+
10
  def __init__(self, base_config, loop_window_size=64):
11
  self.base_config = base_config
12
  self.loop_window_size = loop_window_size
 
14
  def __getattr__(self, name):
15
  return getattr(self.base_config, name)
16
 
17
+ # Learned Gate (With Fix for Init Shock)
18
 
19
  class LoopGate(nn.Module):
20
  def __init__(self, num_heads, head_dim):
 
22
  # Initialize weights to near-zero random noise to break symmetry
23
  self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
24
 
25
+ # Initialize bias to +5.0
26
  # Sigmoid(5.0) ≈ 0.993
27
+ # This means the model starts with 99.3% Global Attention (Standard Qwen)
28
+ # and only 0.7% Local Attention. This prevents "garbage" output at step 0.
29
  self.bias = nn.Parameter(torch.full((num_heads,), 5.0))
30
 
31
  def forward(self, query_states):
 
35
 
36
 
37
 
38
+ # Loop Attention Layer
39
+
40
  class Qwen3LoopAttention(nn.Module):
41
  def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
42
  super().__init__()
 
78
  cache_position=None, **kwargs):
79
  bsz, q_len, _ = hidden_states.size()
80
 
81
+ # Standard Projections
82
  query_states = self.q_proj(hidden_states)
83
  key_states = self.k_proj(hidden_states)
84
  value_states = self.v_proj(hidden_states)
 
103
  key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
104
  value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
105
 
 
106
  if self._loop_mode == 1:
107
  # Loop 1: Capture Global Context
108
  self._global_k = key_states_rpt.detach()
 
117
  # Loop 2: Mixed Attention
118
  g = self.gate(query_states)
119
 
120
+
121
  attn_global = F.scaled_dot_product_attention(
122
  query_states, self._global_k, self._global_v,
123
  attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
124
  )
125
 
 
126
  ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
127
  ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
128
  mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))
 
141
  attn_mask=local_mask, is_causal=False
142
  )
143
 
144
+ # Mixing: If Bias=5.0, g ~ 1.0, so result is mostly Global (Standard)
145
  attn_output = g * attn_global + (1.0 - g) * attn_local
146
 
147
  else:
 
187
  use_cache=None, output_attentions=None, output_hidden_states=None,
188
  return_dict=None, cache_position=None, **kwargs):
189
 
190
+ # If generating (use_cache=True), we disable the loop logic.
191
  if use_cache or (use_cache is None and self.config.use_cache and not self.training):
192
+ # Standard forward - bypass loop logic
193
  for layer in self.model.layers:
194
  layer.self_attn._loop_mode = 0
195
  return self._forward_standard(
 
207
  **kwargs
208
  )
209
 
210
+ # Loop 1: Capture Global
211
  for layer in self.model.layers:
212
  layer.self_attn._loop_mode = 1
213
  with torch.no_grad():
 
221
  **kwargs
222
  )
223
 
224
+ # Loop 2: Mix
225
  for layer in self.model.layers:
226
  layer.self_attn._loop_mode = 2
227
  outputs = self._forward_standard(
 
238
  **kwargs
239
  )
240
 
241
+ # Cleanup
242
  for layer in self.model.layers:
243
  layer.self_attn._loop_mode = 0
244
  layer.self_attn._global_k = None
 
296
 
297
  def generate(self, input_ids=None, **kwargs):
298
  """Generate text - always uses standard attention."""
299
+ # Ensure we use standard mode for generation
300
  for layer in self.model.layers:
301
  layer.self_attn._loop_mode = 0
302
  layer.self_attn._global_k = None
 
348
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
349
  attention_mask=None, inputs_embeds=None,
350
  cache_position=None, **kwargs):
351
+
352
+ # If we have past key values, only use last token
353
  if past_key_values is not None:
354
  if inputs_embeds is not None:
355
  input_ids = input_ids[:, -cache_position.shape[0]:]
 
383
  total = sum(p.numel() for p in self.parameters())
384
  print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
385
 
386
+ def enable_gate_and_layernorm_training(self):
387
+ self.requires_grad_(False)
388
+
389
+ # Unfreeze gates
390
+ for layer in self.model.layers:
391
+ layer.self_attn.gate.requires_grad_(True)
392
+ # Unfreeze layer norms
393
+ layer.input_layernorm.requires_grad_(True)
394
+ layer.post_attention_layernorm.requires_grad_(True)
395
+ # Unfreeze Q/K norms in attention
396
+ layer.self_attn.q_norm.requires_grad_(True)
397
+ layer.self_attn.k_norm.requires_grad_(True)
398
+
399
+ # Unfreeze final layer norm
400
+ self.model.norm.requires_grad_(True)
401
+
402
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
403
+ total = sum(p.numel() for p in self.parameters())
404
+ print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
405
+
406
  def get_gate_parameters(self):
 
407
  params = []
408
  for layer in self.model.layers:
409
  params.extend(layer.self_attn.gate.parameters())
410
  return params
411
+
412
+ def get_trainable_parameters(self):
413
+ return [p for p in self.parameters() if p.requires_grad]
414
+
415
+ def save_pretrained(self, save_directory):
416
+ """Save the model weights and configuration."""
417
+ import os
418
+ os.makedirs(save_directory, exist_ok=True)
419
+
420
+ # Save config / added .bin compatibility
421
+ self.config.save_pretrained(save_directory)
422
+
423
+ torch.save(self.state_dict(), os.path.join(save_directory, "qwen3looped.bin"))
424
+ print(f"Model saved to {save_directory}")