# Qwen3-0.6B-Looped / modeling_qwen_loop.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, apply_rotary_pos_emb, repeat_kv
class Qwen3LoopConfig:
def __init__(self, base_config, loop_window_size=64):
self.base_config = base_config
self.loop_window_size = loop_window_size
def __getattr__(self, name):
return getattr(self.base_config, name)
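
# Usage sketch (illustrative only; "Qwen/Qwen3-0.6B" is an assumed checkpoint id):
#   base_cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
#   cfg = Qwen3LoopConfig(base_cfg, loop_window_size=64)
#   cfg.loop_window_size   # -> 64 (stored on the wrapper)
#   cfg.hidden_size        # -> delegated to the wrapped base config via __getattr__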
# Learned Gate (With Fix for Init Shock)
class LoopGate(nn.Module):
def __init__(self, num_heads, head_dim):
super().__init__()
# Initialize weights to near-zero random noise to break symmetry
self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
# Initialize bias to +5.0
# Sigmoid(5.0) ≈ 0.993
# This means the model starts with 99.3% Global Attention (Standard Qwen)
# and only 0.7% Local Attention. This prevents "garbage" output at step 0.
self.bias = nn.Parameter(torch.full((num_heads,), 5.0))
def forward(self, query_states):
# [batch, heads, seq, dim] -> [batch, heads, seq, 1]
gate_logits = torch.einsum('bhsd,hd->bhs', query_states, self.weight) + self.bias.view(1, -1, 1)
return torch.sigmoid(gate_logits).unsqueeze(-1)
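
# Shape sanity check (illustrative only; the sizes below are arbitrary, not
# Qwen3-0.6B's real dimensions):
#   gate = LoopGate(num_heads=8, head_dim=64)
#   q = torch.randn(2, 8, 16, 64)   # [batch, heads, seq, head_dim]
#   g = gate(q)                     # -> shape (2, 8, 16, 1), values in (0, 1)
#   # at init, g is roughly 0.993 everywhere, i.e. almost entirely global attention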
# Loop Attention Layer
class Qwen3LoopAttention(nn.Module):
def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
super().__init__()
self.loop_window_size = loop_window_size
self.layer_idx = original_attn.layer_idx
# Get config values
config = original_attn.config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = original_attn.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = original_attn.num_key_value_groups
self.scaling = original_attn.scaling
self.is_causal = original_attn.is_causal
# Qwen3 uses head_dim * num_heads which may differ from hidden_size
self.attn_hidden_size = self.num_heads * self.head_dim
# Share weights by reference (No extra memory)
self.q_proj = original_attn.q_proj
self.k_proj = original_attn.k_proj
self.v_proj = original_attn.v_proj
self.o_proj = original_attn.o_proj
# Qwen3 specific: q_norm and k_norm
self.q_norm = original_attn.q_norm
self.k_norm = original_attn.k_norm
# New Gate
self.gate = LoopGate(self.num_heads, self.head_dim)
        # Loop state:
        #   0 = standard attention (inference / generation fallback)
        #   1 = first pass: run standard attention and capture global K/V
        #   2 = second pass: mix cached global attention with local windowed attention
        self._loop_mode = 0
        self._global_k = None
        self._global_v = None
def forward(self, hidden_states, position_embeddings,
attention_mask=None, past_key_values=None,
cache_position=None, **kwargs):
bsz, q_len, _ = hidden_states.size()
# Standard Projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Qwen3: Apply Q/K normalization
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
# RoPE - Qwen3 passes position_embeddings from model level
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Update KV Cache
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
if self._loop_mode == 1:
# Loop 1: Capture Global Context
self._global_k = key_states_rpt.detach()
self._global_v = value_states_rpt.detach()
attn_output = F.scaled_dot_product_attention(
query_states, key_states_rpt, value_states_rpt,
attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
)
elif self._loop_mode == 2:
# Loop 2: Mixed Attention
g = self.gate(query_states)
attn_global = F.scaled_dot_product_attention(
query_states, self._global_k, self._global_v,
attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
)
ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))
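            # Worked example (loop_window_size=3, q_len=5): query position 4 may attend
            # to key positions 2, 3 and 4; position 1 only to 0 and 1. In loop mode the
            # KV cache is disabled, so the key length equals q_len and the indices line up.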
# Create local attention mask
local_mask = torch.full(
(1, 1, q_len, key_states.shape[2]),
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype
)
local_mask.masked_fill_(mask_window, 0.0)
attn_local = F.scaled_dot_product_attention(
query_states, key_states_rpt, value_states_rpt,
attn_mask=local_mask, is_causal=False
)
            # Mixing: with the initial bias of 5.0, g is roughly 0.993, so the output starts as mostly global (standard) attention
attn_output = g * attn_global + (1.0 - g) * attn_local
else:
# Standard (for Inference/Generation fallback)
attn_output = F.scaled_dot_product_attention(
query_states, key_states_rpt, value_states_rpt,
attn_mask=attention_mask, is_causal=self.is_causal and attention_mask is None
)
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.attn_hidden_size)
attn_output = self.o_proj(attn_output)
# Qwen3 expects (attn_output, attn_weights)
return attn_output, None
class Qwen3LoopForCausalLM(nn.Module):
"""Wrapper that adds Loop Attention to Qwen3."""
def __init__(self, base_model, loop_window_size=64):
super().__init__()
self.model = base_model.model
self.lm_head = base_model.lm_head
self.config = base_model.config
self.loop_window_size = loop_window_size
self.generation_config = base_model.generation_config
# Replace attention layers with loop versions
for layer in self.model.layers:
if not isinstance(layer.self_attn, Qwen3LoopAttention):
new_attn = Qwen3LoopAttention(layer.self_attn, loop_window_size)
new_attn.to(layer.self_attn.q_proj.weight.device)
new_attn.to(layer.self_attn.q_proj.weight.dtype)
layer.self_attn = new_attn
@classmethod
def from_pretrained(cls, model_path, loop_window_size=64, **kwargs):
base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
return cls(base, loop_window_size)
def forward(self, input_ids=None, attention_mask=None, position_ids=None,
past_key_values=None, inputs_embeds=None, labels=None,
use_cache=None, output_attentions=None, output_hidden_states=None,
return_dict=None, cache_position=None, **kwargs):
# If generating (use_cache=True), we disable the loop logic.
if use_cache or (use_cache is None and self.config.use_cache and not self.training):
# Standard forward - bypass loop logic
for layer in self.model.layers:
layer.self_attn._loop_mode = 0
return self._forward_standard(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs
)
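        # Training path: run the model twice. Pass 1 (no grad) caches the global K/V
        # in every attention layer; pass 2 recomputes with the gated global/local mix.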
# Loop 1: Capture Global
for layer in self.model.layers:
layer.self_attn._loop_mode = 1
with torch.no_grad():
self._forward_standard(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=inputs_embeds,
use_cache=False,
**kwargs
)
# Loop 2: Mix
for layer in self.model.layers:
layer.self_attn._loop_mode = 2
outputs = self._forward_standard(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=False,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs
)
# Cleanup
for layer in self.model.layers:
layer.self_attn._loop_mode = 0
layer.self_attn._global_k = None
layer.self_attn._global_v = None
return outputs
def _forward_standard(self, input_ids=None, attention_mask=None, position_ids=None,
past_key_values=None, inputs_embeds=None, labels=None,
use_cache=None, output_attentions=None, output_hidden_states=None,
return_dict=None, cache_position=None, **kwargs):
"""Standard forward pass through the model."""
from transformers.modeling_outputs import CausalLMOutputWithPast
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get hidden states from model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def generate(self, input_ids=None, **kwargs):
"""Generate text - always uses standard attention."""
# Ensure we use standard mode for generation
for layer in self.model.layers:
layer.self_attn._loop_mode = 0
layer.self_attn._global_k = None
layer.self_attn._global_v = None
        # Simple greedy/sampling generation loop. Note: no KV cache is reused between
        # steps, so each iteration recomputes the full sequence (correct but slow).
device = input_ids.device
max_new_tokens = kwargs.get('max_new_tokens', 50)
temperature = kwargs.get('temperature', 1.0)
do_sample = kwargs.get('do_sample', False)
top_p = kwargs.get('top_p', 1.0)
pad_token_id = kwargs.get('pad_token_id', self.config.eos_token_id)
eos_token_id = kwargs.get('eos_token_id', self.config.eos_token_id)
generated = input_ids.clone()
for _ in range(max_new_tokens):
with torch.no_grad():
outputs = self(input_ids=generated, use_cache=True)
next_token_logits = outputs.logits[:, -1, :]
if do_sample and temperature > 0:
next_token_logits = next_token_logits / temperature
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=-1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
attention_mask=None, inputs_embeds=None,
cache_position=None, **kwargs):
# If we have past key values, only use last token
if past_key_values is not None:
if inputs_embeds is not None:
input_ids = input_ids[:, -cache_position.shape[0]:]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1]:]
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache", True),
"attention_mask": attention_mask,
}
return model_inputs
def enable_gate_training_only(self):
"""Freeze all parameters except gates."""
self.requires_grad_(False)
for layer in self.model.layers:
layer.self_attn.gate.requires_grad_(True)
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
def enable_gate_and_layernorm_training(self):
self.requires_grad_(False)
# Unfreeze gates
for layer in self.model.layers:
layer.self_attn.gate.requires_grad_(True)
# Unfreeze layer norms
layer.input_layernorm.requires_grad_(True)
layer.post_attention_layernorm.requires_grad_(True)
# Unfreeze Q/K norms in attention
layer.self_attn.q_norm.requires_grad_(True)
layer.self_attn.k_norm.requires_grad_(True)
# Unfreeze final layer norm
self.model.norm.requires_grad_(True)
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")
def get_gate_parameters(self):
params = []
for layer in self.model.layers:
params.extend(layer.self_attn.gate.parameters())
return params
def get_trainable_parameters(self):
return [p for p in self.parameters() if p.requires_grad]
def save_pretrained(self, save_directory):
"""Save the model weights and configuration."""
import os
os.makedirs(save_directory, exist_ok=True)
        # Save the config and the full state dict as a .bin file for compatibility
self.config.save_pretrained(save_directory)
torch.save(self.state_dict(), os.path.join(save_directory, "qwen3looped.bin"))
print(f"Model saved to {save_directory}")