import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, apply_rotary_pos_emb, repeat_kv
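
# Loop Attention wrapper for Qwen3.
#
# Each attention layer runs in one of three modes, selected by `_loop_mode`:
#   0 -- standard causal attention (also used for cached generation),
#   1 -- a first pass that records each layer's full key/value states (detached),
#   2 -- a second pass that mixes global attention over the recorded K/V with
#        sliding-window local attention, blended by a learned per-head gate.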


class Qwen3LoopConfig:
    def __init__(self, base_config, loop_window_size=64):
        self.base_config = base_config
        self.loop_window_size = loop_window_size

    def __getattr__(self, name):
        # Guard against infinite recursion when base_config itself is missing
        # (e.g. during copy/unpickling, before __init__ has run).
        if name == "base_config":
            raise AttributeError(name)
        return getattr(self.base_config, name)
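

# The gate computes a per-head, per-position scalar
#     g = sigmoid(q . w_h + b_h)  in (0, 1)
# from the query states. With the bias initialised to 5.0, g starts near
# sigmoid(5) ~ 0.993, so the mix is almost entirely the global branch and the
# wrapped model initially behaves like the base model.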
class LoopGate(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_heads, head_dim) * 0.01)
        self.bias = nn.Parameter(torch.full((num_heads,), 5.0))

    def forward(self, query_states):
        # query_states: (batch, num_heads, seq_len, head_dim)
        gate_logits = torch.einsum('bhsd,hd->bhs', query_states, self.weight) + self.bias.view(1, -1, 1)
        # Returns (batch, num_heads, seq_len, 1), values in (0, 1).
        return torch.sigmoid(gate_logits).unsqueeze(-1)
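

# Drop-in replacement for Qwen3Attention. It holds references to the original
# projection and norm modules (no weights are copied) and adds only the gate
# parameters on top.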
class Qwen3LoopAttention(nn.Module):
    def __init__(self, original_attn: Qwen3Attention, loop_window_size: int = 64):
        super().__init__()
        self.loop_window_size = loop_window_size
        self.layer_idx = original_attn.layer_idx

        config = original_attn.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = original_attn.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = original_attn.num_key_value_groups
        self.scaling = original_attn.scaling
        self.is_causal = original_attn.is_causal
        self.attn_hidden_size = self.num_heads * self.head_dim

        # Shared references to the original projections and norms.
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj
        self.q_norm = original_attn.q_norm
        self.k_norm = original_attn.k_norm

        self.gate = LoopGate(self.num_heads, self.head_dim)

        # Runtime state: 0 = standard, 1 = record K/V, 2 = gated global+local.
        self._loop_mode = 0
        self._global_k = None
        self._global_v = None

    def forward(self, hidden_states, position_embeddings,
                attention_mask=None, past_key_values=None,
                cache_position=None, **kwargs):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states_rpt = repeat_kv(key_states, self.num_key_value_groups)
        value_states_rpt = repeat_kv(value_states, self.num_key_value_groups)
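
        # Pass 1: run standard attention while recording this layer's full
        # (already RoPE'd and head-repeated) K/V for the global branch of
        # pass 2. They are detached so pass 1 contributes no gradients.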
        if self._loop_mode == 1:
            self._global_k = key_states_rpt.detach()
            self._global_v = value_states_rpt.detach()

            attn_output = F.scaled_dot_product_attention(
                query_states, key_states_rpt, value_states_rpt,
                attn_mask=attention_mask,
                is_causal=self.is_causal and attention_mask is None,
                scale=self.scaling,
            )
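
        # Pass 2: blend global attention over the cached K/V with local
        # sliding-window attention, per head and per position.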
        elif self._loop_mode == 2:
            g = self.gate(query_states)

            # Global branch: attend (causally) over the full-sequence K/V
            # recorded in pass 1.
            attn_global = F.scaled_dot_product_attention(
                query_states, self._global_k, self._global_v,
                attn_mask=attention_mask,
                is_causal=self.is_causal and attention_mask is None,
                scale=self.scaling,
            )
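
            # Local branch: sliding-window mask so that query i may attend
            # keys j with i - loop_window_size < j <= i (causal, fixed width).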
            ids_q = torch.arange(q_len, device=query_states.device).unsqueeze(1)
            ids_k = torch.arange(key_states.shape[2], device=query_states.device).unsqueeze(0)
            mask_window = (ids_k <= ids_q) & (ids_k > (ids_q - self.loop_window_size))

            local_mask = torch.full(
                (1, 1, q_len, key_states.shape[2]),
                torch.finfo(query_states.dtype).min,
                device=query_states.device,
                dtype=query_states.dtype,
            )
            local_mask.masked_fill_(mask_window, 0.0)

            attn_local = F.scaled_dot_product_attention(
                query_states, key_states_rpt, value_states_rpt,
                attn_mask=local_mask, is_causal=False,
                scale=self.scaling,
            )
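
            # Convex combination per head/position: g -> 1 favours the global
            # branch, g -> 0 the local window.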
            attn_output = g * attn_global + (1.0 - g) * attn_local

        else:
            # Mode 0: plain causal attention.
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states_rpt, value_states_rpt,
                attn_mask=attention_mask,
                is_causal=self.is_causal and attention_mask is None,
                scale=self.scaling,
            )

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.attn_hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None


class Qwen3LoopForCausalLM(nn.Module):
    """Wrapper that adds Loop Attention to Qwen3."""

    def __init__(self, base_model, loop_window_size=64):
        super().__init__()
        self.model = base_model.model
        self.lm_head = base_model.lm_head
        self.config = base_model.config
        self.loop_window_size = loop_window_size
        self.generation_config = base_model.generation_config

        # Swap every attention module for its loop-attention counterpart.
        for layer in self.model.layers:
            if not isinstance(layer.self_attn, Qwen3LoopAttention):
                new_attn = Qwen3LoopAttention(layer.self_attn, loop_window_size)
                new_attn.to(device=layer.self_attn.q_proj.weight.device,
                            dtype=layer.self_attn.q_proj.weight.dtype)
                layer.self_attn = new_attn

    @classmethod
    def from_pretrained(cls, model_path, loop_window_size=64, **kwargs):
        base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
        return cls(base, loop_window_size)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, labels=None,
                use_cache=None, output_attentions=None, output_hidden_states=None,
                return_dict=None, cache_position=None, **kwargs):
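        # Cached / generation path: fall back to plain causal attention.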
        if use_cache or (use_cache is None and self.config.use_cache and not self.training):
            for layer in self.model.layers:
                layer.self_attn._loop_mode = 0
            return self._forward_standard(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs
            )
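
        # Pass 1 (no grad): populate each layer's _global_k / _global_v.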
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 1
        with torch.no_grad():
            self._forward_standard(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=inputs_embeds,
                use_cache=False,
                **kwargs
            )
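
        # Pass 2 (with grad): gated global + local attention; this pass
        # produces the logits and the loss.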
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 2
        outputs = self._forward_standard(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
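
        # Reset to standard attention and free the cached K/V.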
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 0
            layer.self_attn._global_k = None
            layer.self_attn._global_v = None

        return outputs

    def _forward_standard(self, input_ids=None, attention_mask=None, position_ids=None,
                          past_key_values=None, inputs_embeds=None, labels=None,
                          use_cache=None, output_attentions=None, output_hidden_states=None,
                          return_dict=None, cache_position=None, **kwargs):
        """Standard forward pass through the model."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, input_ids=None, **kwargs):
        """Generate text - always uses standard attention."""
        for layer in self.model.layers:
            layer.self_attn._loop_mode = 0
            layer.self_attn._global_k = None
            layer.self_attn._global_v = None

        max_new_tokens = kwargs.get('max_new_tokens', 50)
        temperature = kwargs.get('temperature', 1.0)
        do_sample = kwargs.get('do_sample', False)
        top_p = kwargs.get('top_p', 1.0)
        eos_token_id = kwargs.get('eos_token_id', self.config.eos_token_id)

        generated = input_ids.clone()
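
        # Note: each step re-runs the full prefix (past_key_values is not
        # threaded through), so this loop is simple but slower than
        # cache-based decoding.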
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self(input_ids=generated, use_cache=True)
                next_token_logits = outputs.logits[:, -1, :]

            if do_sample and temperature > 0:
                next_token_logits = next_token_logits / temperature
                if top_p < 1.0:
                    # Nucleus (top-p) filtering.
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      attention_mask=None, inputs_embeds=None,
                                      cache_position=None, **kwargs):
        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        model_inputs = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "attention_mask": attention_mask,
        }
        return model_inputs

    def enable_gate_training_only(self):
        """Freeze all parameters except gates."""
        self.requires_grad_(False)
        for layer in self.model.layers:
            layer.self_attn.gate.requires_grad_(True)

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")

    def enable_gate_and_layernorm_training(self):
        """Freeze everything except the gates and the RMSNorm layers."""
        self.requires_grad_(False)
        for layer in self.model.layers:
            layer.self_attn.gate.requires_grad_(True)
            layer.input_layernorm.requires_grad_(True)
            layer.post_attention_layernorm.requires_grad_(True)
            layer.self_attn.q_norm.requires_grad_(True)
            layer.self_attn.k_norm.requires_grad_(True)
        self.model.norm.requires_grad_(True)

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.4f}%)")

    def get_gate_parameters(self):
        params = []
        for layer in self.model.layers:
            params.extend(layer.self_attn.gate.parameters())
        return params

    def get_trainable_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]

    def save_pretrained(self, save_directory):
        """Save the model weights and configuration."""
        os.makedirs(save_directory, exist_ok=True)

        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "qwen3looped.bin"))
        print(f"Model saved to {save_directory}")
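

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module proper). The checkpoint name is
# a placeholder assumption -- substitute any Qwen3 causal LM you have access
# to. The training text and hyperparameters are likewise illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_id = "Qwen/Qwen3-0.6B"  # hypothetical example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Qwen3LoopForCausalLM.from_pretrained(model_id, loop_window_size=64)

    # Train only the per-layer gates (a tiny fraction of all parameters).
    model.enable_gate_training_only()
    model.train()
    optimizer = torch.optim.AdamW(model.get_gate_parameters(), lr=1e-3)

    batch = tokenizer("Loop attention mixes global and local context.", return_tensors="pt")
    outputs = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
    outputs.loss.backward()
    optimizer.step()

    # Generation always runs with standard (mode 0) attention.
    model.eval()
    generated = model.generate(batch["input_ids"], max_new_tokens=20)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))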