import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Attention,
    apply_rotary_pos_emb,
    repeat_kv,
)


class Qwen3LoopConfig:
    def __init__(self, base_config, loop_window_size=64):
        self.base_config = base_config
        self.loop_window_size = loop_window_size

    def __getattr__(self, name):
        return getattr(self.base_config, name)


# Learned gate (with fix for init shock)
class LoopGate(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        # Initialize weights to near-zero random noise to break symmetry.
        self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
        # Initialize bias to +5.0: sigmoid(5.0) ≈ 0.993, so the model starts with
        # ~99.3% global attention (standard Qwen) and only ~0.7% local attention.
        # This prevents "garbage" output at step 0.
        self.bias = nn.Parameter(torch.full((num_heads,), 5.0))

    def forward(self, query_states):
        # [batch, heads, seq, dim] -> [batch, heads, seq, 1]
        gate_logits = torch.einsum('bhsd,hd->bhs', query_states, self.weight) + self.bias.view(1, -1, 1)
        return torch.sigmoid(gate_logits).unsqueeze(-1)


# Loop attention layer
class Qwen3LoopAttention(nn.Module):
    def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
        super().__init__()
        self.loop_window_size = loop_window_size
        self.layer_idx = original_attn.layer_idx

        # Get config values
        config = original_attn.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = original_attn.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = original_attn.num_key_value_groups
        self.scaling = original_attn.scaling
        self.is_causal = original_attn.is_causal
        # Qwen3 uses num_heads * head_dim, which may differ from hidden_size
        self.attn_hidden_size = self.num_heads * self.head_dim

        # Share weights by reference (no extra memory)
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj

        # Qwen3-specific: q_norm and k_norm
        self.q_norm = original_attn.q_norm
        self.k_norm = original_attn.k_norm

        # New gate
        self.gate = LoopGate(self.num_heads, self.head_dim)

        # Loop state
        self._loop_mode = 0
        self._global_k = None
        self._global_v = None

    def forward(self, hidden_states, position_embeddings, attention_mask=None,
                past_key_values=None, cache_position=None, **kwargs):
        bsz, q_len, _ = hidden_states.size()

        # Standard projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Qwen3: apply Q/K normalization (RMSNorm over head_dim)
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        # RoPE - Qwen3 passes position_embeddings down from the model level
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Update the KV cache
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
        value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)

        if self._loop_mode == 1:
            # Loop 1: capture global context
            self._global_k = key_states_rpt.detach()
            self._global_v = value_states_rpt.detach()
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states_rpt, value_states_rpt,
                attn_mask=attention_mask,
                is_causal=self.is_causal and attention_mask is None,
            )
        elif self._loop_mode == 2:
            # Loop 2: mixed attention
            g = self.gate(query_states)

            attn_global = F.scaled_dot_product_attention(
                query_states, self._global_k, self._global_v,
                attn_mask=attention_mask,
                is_causal=self.is_causal and attention_mask is None,
            )

            # Build a causal sliding-window mask of width loop_window_size
            ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
            ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
            mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))

            local_mask = torch.full(
                (1, 1, q_len, key_states.shape[2]),
                torch.finfo(query_states.dtype).min,
                device=query_states.device, dtype=query_states.dtype,
            )
            local_mask.masked_fill_(mask_window, 0.0)

            attn_local = F.scaled_dot_product_attention(
                query_states, key_states_rpt, value_states_rpt,
                attn_mask=local_mask, is_causal=False,
            )

            # Mixing: with bias = 5.0, g ≈ 1.0, so the result is mostly global (standard)
            attn_output = g * attn_global + (1.0 - g) * attn_local
        else:
            # Standard attention (inference/generation fallback)
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states_rpt, value_states_rpt,
                attn_mask=attention_mask,
                is_causal=self.is_causal and attention_mask is None,
            )

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.attn_hidden_size)
        attn_output = self.o_proj(attn_output)

        # Qwen3 expects (attn_output, attn_weights)
        return attn_output, None


class Qwen3LoopForCausalLM(nn.Module):
    """Wrapper that adds Loop Attention to Qwen3."""

    def __init__(self, base_model, loop_window_size=64):
        super().__init__()
        self.model = base_model.model
        self.lm_head = base_model.lm_head
        self.config = base_model.config
        self.loop_window_size = loop_window_size
        self.generation_config = base_model.generation_config

        # Replace attention layers with loop versions
        for layer in self.model.layers:
            if not isinstance(layer.self_attn, Qwen3LoopAttention):
                new_attn = Qwen3LoopAttention(layer.self_attn, loop_window_size)
                new_attn.to(layer.self_attn.q_proj.weight.device)
                new_attn.to(layer.self_attn.q_proj.weight.dtype)
                layer.self_attn = new_attn

    @classmethod
    def from_pretrained(cls, model_path, loop_window_size=64, **kwargs):
        base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
        return cls(base, loop_window_size)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, labels=None, use_cache=None,
                output_attentions=None, output_hidden_states=None, return_dict=None,
                cache_position=None, **kwargs):
        # If generating (use_cache=True), disable the loop logic.
        if use_cache or (use_cache is None and self.config.use_cache and not self.training):
            # Standard forward - bypass the loop logic
            for layer in self.model.layers:
                layer.self_attn._loop_mode = 0
            return self._forward_standard(
                input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
                past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels,
                use_cache=use_cache, output_attentions=output_attentions,
                output_hidden_states=output_hidden_states, return_dict=return_dict,
                cache_position=cache_position, **kwargs
            )

        # Loop 1: capture global context (no gradients needed for this pass)
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 1
        with torch.no_grad():
            self._forward_standard(
                input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
                past_key_values=None, inputs_embeds=inputs_embeds, use_cache=False, **kwargs
            )

        # Loop 2: mix global and local attention
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 2
        outputs = self._forward_standard(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=None, inputs_embeds=inputs_embeds, labels=labels, use_cache=False,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict, **kwargs
        )

        # Cleanup
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 0
            layer.self_attn._global_k = None
            layer.self_attn._global_v = None

        return outputs

    def _forward_standard(self, input_ids=None, attention_mask=None, position_ids=None,
                          past_key_values=None, inputs_embeds=None, labels=None, use_cache=None,
                          output_attentions=None, output_hidden_states=None, return_dict=None,
                          cache_position=None, **kwargs):
        """Standard forward pass through the model."""
        from transformers.modeling_outputs import CausalLMOutputWithPast

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get hidden states from the base model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, input_ids=None, **kwargs):
        """Generate text - always uses standard attention."""
        # Ensure standard mode for generation
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 0
            layer.self_attn._global_k = None
            layer.self_attn._global_v = None

        # Simple generation loop. Note: this recomputes the full prefix at every
        # step instead of reusing the KV cache, so it is correct but slow.
        max_new_tokens = kwargs.get('max_new_tokens', 50)
        temperature = kwargs.get('temperature', 1.0)
        do_sample = kwargs.get('do_sample', False)
        top_p = kwargs.get('top_p', 1.0)
        eos_token_id = kwargs.get('eos_token_id', self.config.eos_token_id)

        generated = input_ids.clone()
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self(input_ids=generated, use_cache=True)
            next_token_logits = outputs.logits[:, -1, :]

            if do_sample and temperature > 0:
                next_token_logits = next_token_logits / temperature
                if top_p < 1.0:
                    # Nucleus (top-p) filtering
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None,
                                      inputs_embeds=None, cache_position=None, **kwargs):
        # If we have past key values, only feed the tokens that are not yet cached
        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        model_inputs = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "attention_mask": attention_mask,
        }
        return model_inputs

    def enable_gate_training_only(self):
        """Freeze all parameters except the gates."""
        self.requires_grad_(False)
        for layer in self.model.layers:
            layer.self_attn.gate.requires_grad_(True)
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.4f}%)")

    def enable_gate_and_layernorm_training(self):
        """Freeze all parameters except the gates and the layer norms."""
        self.requires_grad_(False)
        for layer in self.model.layers:
            # Unfreeze gates
            layer.self_attn.gate.requires_grad_(True)
            # Unfreeze layer norms
            layer.input_layernorm.requires_grad_(True)
            layer.post_attention_layernorm.requires_grad_(True)
            # Unfreeze Q/K norms in attention
            layer.self_attn.q_norm.requires_grad_(True)
            layer.self_attn.k_norm.requires_grad_(True)
        # Unfreeze the final layer norm
        self.model.norm.requires_grad_(True)
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.4f}%)")

    def get_gate_parameters(self):
        params = []
        for layer in self.model.layers:
            params.extend(layer.self_attn.gate.parameters())
        return params

    def get_trainable_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]

    def save_pretrained(self, save_directory):
        """Save the model weights and configuration."""
        import os
        os.makedirs(save_directory, exist_ok=True)
        # Save the config; weights go into a single .bin file for compatibility
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "qwen3looped.bin"))
        print(f"Model saved to {save_directory}")
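

# Illustrative usage sketch (not part of the module API above): wrap a base Qwen3
# checkpoint, freeze everything except the gates, and run one training-style forward
# pass through the two-loop path. The checkpoint name and dtype below are assumptions;
# substitute your own local path or Hub ID.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "Qwen/Qwen3-0.6B"  # assumed checkpoint; replace as needed
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = Qwen3LoopForCausalLM.from_pretrained(
        model_name, loop_window_size=64, torch_dtype=torch.bfloat16
    )
    model.enable_gate_training_only()

    batch = tokenizer("Loop attention smoke test.", return_tensors="pt")
    # use_cache=False routes through Loop 1 (capture global KV) and Loop 2 (gated mix)
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["input_ids"],
        use_cache=False,
    )
    print("loss:", out.loss.item())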