Update modeling_qwen_loop.py
Browse files
More comments yay!
Also can now save as .bin & .pt
- modeling_qwen_loop.py +52 -8
modeling_qwen_loop.py
CHANGED
|
@@ -6,6 +6,7 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, apply_rotar
|
|
| 6 |
|
| 7 |
|
| 8 |
class Qwen3LoopConfig:
    """Wrap a base config object and carry an extra ``loop_window_size``.

    Any attribute that is not set on the wrapper itself is looked up on
    ``base_config``.
    """

    def __init__(self, base_config, loop_window_size=64):
        """Store the wrapped config and the loop window size.

        Args:
            base_config: Object whose attributes are exposed through this
                wrapper.
            loop_window_size: Value stored unchanged on the wrapper
                (default 64).
        """
        self.base_config = base_config
        self.loop_window_size = loop_window_size

    def __getattr__(self, name):
        """Forward a lookup that missed on the wrapper to ``base_config``.

        Raises:
            AttributeError: If ``base_config`` has not been set on this
                instance, or if the base config has no attribute ``name``.
        """
        # __getattr__ only runs after normal lookup has failed. If
        # base_config itself is missing (instance built without __init__,
        # as copy/pickle do), reading self.base_config would re-enter this
        # method until RecursionError; report a plain AttributeError instead.
        try:
            base = self.__dict__["base_config"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(base, name)
|
| 15 |
|
|
|
|
| 16 |
|
| 17 |
class LoopGate(nn.Module):
|
| 18 |
def __init__(self, num_heads, head_dim):
|
|
@@ -20,8 +22,10 @@ class LoopGate(nn.Module):
|
|
| 20 |
# Initialize weights to near-zero random noise to break symmetry
|
| 21 |
self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
|
| 22 |
|
| 23 |
-
# Initialize bias to +5.0
|
| 24 |
# Sigmoid(5.0) ≈ 0.993
|
|
|
|
|
|
|
| 25 |
self.bias = nn.Parameter(torch.full((num_heads,), 5.0))
|
| 26 |
|
| 27 |
def forward(self, query_states):
|
|
@@ -31,7 +35,8 @@ class LoopGate(nn.Module):
|
|
| 31 |
|
| 32 |
|
| 33 |
|
| 34 |
-
# Loop Attention
|
|
|
|
| 35 |
class Qwen3LoopAttention(nn.Module):
|
| 36 |
def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
|
| 37 |
super().__init__()
|
|
@@ -73,6 +78,7 @@ class Qwen3LoopAttention(nn.Module):
|
|
| 73 |
cache_position=None, **kwargs):
|
| 74 |
bsz, q_len, _ = hidden_states.size()
|
| 75 |
|
|
|
|
| 76 |
query_states = self.q_proj(hidden_states)
|
| 77 |
key_states = self.k_proj(hidden_states)
|
| 78 |
value_states = self.v_proj(hidden_states)
|
|
@@ -97,7 +103,6 @@ class Qwen3LoopAttention(nn.Module):
|
|
| 97 |
key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
|
| 98 |
value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
|
| 99 |
|
| 100 |
-
|
| 101 |
if self._loop_mode == 1:
|
| 102 |
# Loop 1: Capture Global Context
|
| 103 |
self._global_k = key_states_rpt.detach()
|
|
@@ -112,13 +117,12 @@ class Qwen3LoopAttention(nn.Module):
|
|
| 112 |
# Loop 2: Mixed Attention
|
| 113 |
g = self.gate(query_states)
|
| 114 |
|
| 115 |
-
|
| 116 |
attn_global = F.scaled_dot_product_attention(
|
| 117 |
query_states, self._global_k, self._global_v,
|
| 118 |
attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
|
| 119 |
)
|
| 120 |
|
| 121 |
-
# Local (Windowed)
|
| 122 |
ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
|
| 123 |
ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
|
| 124 |
mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))
|
|
@@ -137,7 +141,7 @@ class Qwen3LoopAttention(nn.Module):
|
|
| 137 |
attn_mask=local_mask, is_causal=False
|
| 138 |
)
|
| 139 |
|
| 140 |
-
# Mixing: If Bias=5.0, g ~ 1.0, so result is mostly
|
| 141 |
attn_output = g * attn_global + (1.0 - g) * attn_local
|
| 142 |
|
| 143 |
else:
|
|
@@ -183,7 +187,9 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 183 |
use_cache=None, output_attentions=None, output_hidden_states=None,
|
| 184 |
return_dict=None, cache_position=None, **kwargs):
|
| 185 |
|
|
|
|
| 186 |
if use_cache or (use_cache is None and self.config.use_cache and not self.training):
|
|
|
|
| 187 |
for layer in self.model.layers:
|
| 188 |
layer.self_attn._loop_mode = 0
|
| 189 |
return self._forward_standard(
|
|
@@ -201,6 +207,7 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 201 |
**kwargs
|
| 202 |
)
|
| 203 |
|
|
|
|
| 204 |
for layer in self.model.layers:
|
| 205 |
layer.self_attn._loop_mode = 1
|
| 206 |
with torch.no_grad():
|
|
@@ -214,6 +221,7 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 214 |
**kwargs
|
| 215 |
)
|
| 216 |
|
|
|
|
| 217 |
for layer in self.model.layers:
|
| 218 |
layer.self_attn._loop_mode = 2
|
| 219 |
outputs = self._forward_standard(
|
|
@@ -230,6 +238,7 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 230 |
**kwargs
|
| 231 |
)
|
| 232 |
|
|
|
|
| 233 |
for layer in self.model.layers:
|
| 234 |
layer.self_attn._loop_mode = 0
|
| 235 |
layer.self_attn._global_k = None
|
|
@@ -287,6 +296,7 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 287 |
|
| 288 |
def generate(self, input_ids=None, **kwargs):
|
| 289 |
"""Generate text - always uses standard attention."""
|
|
|
|
| 290 |
for layer in self.model.layers:
|
| 291 |
layer.self_attn._loop_mode = 0
|
| 292 |
layer.self_attn._global_k = None
|
|
@@ -338,7 +348,8 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 338 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
| 339 |
attention_mask=None, inputs_embeds=None,
|
| 340 |
cache_position=None, **kwargs):
|
| 341 |
-
|
|
|
|
| 342 |
if past_key_values is not None:
|
| 343 |
if inputs_embeds is not None:
|
| 344 |
input_ids = input_ids[:, -cache_position.shape[0]:]
|
|
@@ -372,9 +383,42 @@ class Qwen3LoopForCausalLM(nn.Module):
|
|
| 372 |
total = sum(p.numel() for p in self.parameters())
|
| 373 |
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
|
| 374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
def get_gate_parameters(self):
    """Return list of gate parameters for optimizer.

    Walks ``self.model.layers`` in order and collects everything yielded by
    ``layer.self_attn.gate.parameters()`` into one flat list.
    """
    return [
        param
        for layer in self.model.layers
        for param in layer.self_attn.gate.parameters()
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class Qwen3LoopConfig:
    """Wrap a base config object and carry an extra ``loop_window_size``.

    Any attribute that is not set on the wrapper itself is looked up on
    ``base_config``.
    """

    def __init__(self, base_config, loop_window_size=64):
        """Store the wrapped config and the loop window size.

        Args:
            base_config: Object whose attributes are exposed through this
                wrapper.
            loop_window_size: Value stored unchanged on the wrapper
                (default 64).
        """
        self.base_config = base_config
        self.loop_window_size = loop_window_size

    def __getattr__(self, name):
        """Forward a lookup that missed on the wrapper to ``base_config``.

        Raises:
            AttributeError: If ``base_config`` has not been set on this
                instance, or if the base config has no attribute ``name``.
        """
        # __getattr__ only runs after normal lookup has failed. If
        # base_config itself is missing (instance built without __init__,
        # as copy/pickle do), reading self.base_config would re-enter this
        # method until RecursionError; report a plain AttributeError instead.
        try:
            base = self.__dict__["base_config"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(base, name)
|
| 16 |
|
| 17 |
+
# Learned Gate (With Fix for Init Shock)
|
| 18 |
|
| 19 |
class LoopGate(nn.Module):
|
| 20 |
def __init__(self, num_heads, head_dim):
|
|
|
|
| 22 |
# Initialize weights to near-zero random noise to break symmetry
|
| 23 |
self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
|
| 24 |
|
| 25 |
+
# Initialize bias to +5.0
|
| 26 |
# Sigmoid(5.0) ≈ 0.993
|
| 27 |
+
# This means the model starts with 99.3% Global Attention (Standard Qwen)
|
| 28 |
+
# and only 0.7% Local Attention. This prevents "garbage" output at step 0.
|
| 29 |
self.bias = nn.Parameter(torch.full((num_heads,), 5.0))
|
| 30 |
|
| 31 |
def forward(self, query_states):
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
|
| 38 |
+
# Loop Attention Layer
|
| 39 |
+
|
| 40 |
class Qwen3LoopAttention(nn.Module):
|
| 41 |
def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
|
| 42 |
super().__init__()
|
|
|
|
| 78 |
cache_position=None, **kwargs):
|
| 79 |
bsz, q_len, _ = hidden_states.size()
|
| 80 |
|
| 81 |
+
# Standard Projections
|
| 82 |
query_states = self.q_proj(hidden_states)
|
| 83 |
key_states = self.k_proj(hidden_states)
|
| 84 |
value_states = self.v_proj(hidden_states)
|
|
|
|
| 103 |
key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
|
| 104 |
value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
|
| 105 |
|
|
|
|
| 106 |
if self._loop_mode == 1:
|
| 107 |
# Loop 1: Capture Global Context
|
| 108 |
self._global_k = key_states_rpt.detach()
|
|
|
|
| 117 |
# Loop 2: Mixed Attention
|
| 118 |
g = self.gate(query_states)
|
| 119 |
|
| 120 |
+
|
| 121 |
attn_global = F.scaled_dot_product_attention(
|
| 122 |
query_states, self._global_k, self._global_v,
|
| 123 |
attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
|
| 124 |
)
|
| 125 |
|
|
|
|
| 126 |
ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
|
| 127 |
ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
|
| 128 |
mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))
|
|
|
|
| 141 |
attn_mask=local_mask, is_causal=False
|
| 142 |
)
|
| 143 |
|
| 144 |
+
# Mixing: If Bias=5.0, g ~ 1.0, so result is mostly Global (Standard)
|
| 145 |
attn_output = g * attn_global + (1.0 - g) * attn_local
|
| 146 |
|
| 147 |
else:
|
|
|
|
| 187 |
use_cache=None, output_attentions=None, output_hidden_states=None,
|
| 188 |
return_dict=None, cache_position=None, **kwargs):
|
| 189 |
|
| 190 |
+
# If generating (use_cache=True), we disable the loop logic.
|
| 191 |
if use_cache or (use_cache is None and self.config.use_cache and not self.training):
|
| 192 |
+
# Standard forward - bypass loop logic
|
| 193 |
for layer in self.model.layers:
|
| 194 |
layer.self_attn._loop_mode = 0
|
| 195 |
return self._forward_standard(
|
|
|
|
| 207 |
**kwargs
|
| 208 |
)
|
| 209 |
|
| 210 |
+
# Loop 1: Capture Global
|
| 211 |
for layer in self.model.layers:
|
| 212 |
layer.self_attn._loop_mode = 1
|
| 213 |
with torch.no_grad():
|
|
|
|
| 221 |
**kwargs
|
| 222 |
)
|
| 223 |
|
| 224 |
+
# Loop 2: Mix
|
| 225 |
for layer in self.model.layers:
|
| 226 |
layer.self_attn._loop_mode = 2
|
| 227 |
outputs = self._forward_standard(
|
|
|
|
| 238 |
**kwargs
|
| 239 |
)
|
| 240 |
|
| 241 |
+
# Cleanup
|
| 242 |
for layer in self.model.layers:
|
| 243 |
layer.self_attn._loop_mode = 0
|
| 244 |
layer.self_attn._global_k = None
|
|
|
|
| 296 |
|
| 297 |
def generate(self, input_ids=None, **kwargs):
|
| 298 |
"""Generate text - always uses standard attention."""
|
| 299 |
+
# Ensure we use standard mode for generation
|
| 300 |
for layer in self.model.layers:
|
| 301 |
layer.self_attn._loop_mode = 0
|
| 302 |
layer.self_attn._global_k = None
|
|
|
|
| 348 |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
| 349 |
attention_mask=None, inputs_embeds=None,
|
| 350 |
cache_position=None, **kwargs):
|
| 351 |
+
|
| 352 |
+
# If we have past key values, only use last token
|
| 353 |
if past_key_values is not None:
|
| 354 |
if inputs_embeds is not None:
|
| 355 |
input_ids = input_ids[:, -cache_position.shape[0]:]
|
|
|
|
| 383 |
total = sum(p.numel() for p in self.parameters())
|
| 384 |
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
|
| 385 |
|
| 386 |
+
def enable_gate_and_layernorm_training(self):
    """Freeze every parameter, then re-enable gradients for gates and norms.

    After the call only these modules have ``requires_grad`` turned on:
    each layer's ``self_attn.gate``, ``input_layernorm``,
    ``post_attention_layernorm``, ``self_attn.q_norm`` and
    ``self_attn.k_norm``, plus the final ``self.model.norm``. A one-line
    summary of trainable vs. total parameter counts is printed.
    """
    # Start from a fully frozen model so nothing else stays trainable.
    self.requires_grad_(False)

    for layer in self.model.layers:
        attn = layer.self_attn
        # Unfreeze gates
        attn.gate.requires_grad_(True)
        # Unfreeze layer norms
        layer.input_layernorm.requires_grad_(True)
        layer.post_attention_layernorm.requires_grad_(True)
        # Unfreeze Q/K norms in attention
        attn.q_norm.requires_grad_(True)
        attn.k_norm.requires_grad_(True)

    # Unfreeze final layer norm
    self.model.norm.requires_grad_(True)

    trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
    total = sum(p.numel() for p in self.parameters())
    # Avoid ZeroDivisionError on a model that has no parameters at all.
    percent = 100 * trainable / total if total else 0.0
    print(f"Trainable: {trainable:,} / {total:,} ({percent:.4f}%)")
|
| 405 |
+
|
| 406 |
def get_gate_parameters(self):
    """Return list of gate parameters for optimizer.

    Walks ``self.model.layers`` in order and collects everything yielded by
    ``layer.self_attn.gate.parameters()`` into one flat list.
    """
    return [
        param
        for layer in self.model.layers
        for param in layer.self_attn.gate.parameters()
    ]
|
| 411 |
+
|
| 412 |
+
def get_trainable_parameters(self):
    """Return every parameter of ``self`` whose ``requires_grad`` is set.

    The list follows the iteration order of ``self.parameters()``.
    """
    return [param for param in self.parameters() if param.requires_grad]
|
| 414 |
+
|
| 415 |
+
def save_pretrained(self, save_directory):
    """Save the model weights and configuration.

    Creates ``save_directory`` if needed, asks ``self.config`` to save itself
    there, then writes ``self.state_dict()`` with ``torch.save`` to
    ``qwen3looped.bin`` inside that directory and prints a confirmation.

    Args:
        save_directory: Target directory path; created when missing.
    """
    import os

    os.makedirs(save_directory, exist_ok=True)

    # Save config / added .bin compatibility
    # NOTE(review): if self.config is a wrapper that forwards unknown
    # attributes to a base config, this call may save only the base config
    # and drop wrapper-only fields such as loop_window_size -- confirm.
    self.config.save_pretrained(save_directory)

    # NOTE(review): only a .bin file is written here; the commit description
    # also mentions .pt -- confirm whether a second format is expected.
    weights_path = os.path.join(save_directory, "qwen3looped.bin")
    torch.save(self.state_dict(), weights_path)
    print(f"Model saved to {save_directory}")
|