"""Qwen3 causal LM variant with a capped LM head, QK RMS-norm and a custom RoPE table."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers import Qwen3ForCausalLM
from transformers.activations import ACT2FN

from .configuration_patched import PatchedQwen3Config


# ---------------------------------------------------------
# Define the Custom Components
# ---------------------------------------------------------
class SquaredReLU(nn.Module):
    """Activation module computing ``relu(x) ** 2`` element-wise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the square of the rectified input."""
        return F.relu(x).square()


# Register it globally so the base architecture can find it
ACT2FN["squared_relu"] = SquaredReLU


class CappedLMHead(nn.Module):
    """Wrap an LM head and smoothly squash its logits into ``(-cap, cap)``.

    The wrapped head is stored as the submodule ``base_head``, so its
    parameters appear under ``<name>.base_head.*`` in the state dict.

    Args:
        base_head: Module that maps hidden states to raw logits.
        cap: Bound on the absolute value of the returned logits. Defaults to
            15, matching the previously hard-coded constants (15 and 225).
    """

    def __init__(self, base_head: nn.Module, cap: float = 15.0):
        super().__init__()
        self.base_head = base_head
        self.cap = cap

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``cap * z / sqrt(z**2 + cap**2)`` for ``z = base_head(x)``."""
        logits = self.base_head(x)
        cap = self.cap
        # rsqrt form of the soft cap: ~identity near 0, saturates at +/- cap.
        return cap * logits * torch.rsqrt(logits.square() + cap * cap)


class RMSNormLinear(nn.Module):
    """Linear projection followed by a weight-free RMS norm over each head.

    The wrapped projection is stored as the submodule ``base_linear``, so its
    parameters appear under ``<name>.base_linear.*`` in the state dict.

    Args:
        base_linear: Projection whose last output dimension is a multiple of
            ``head_dim``.
        head_dim: Size of each chunk of the last dimension that is normalised
            independently.
    """

    def __init__(self, base_linear: nn.Module, head_dim: int):
        super().__init__()
        self.base_linear = base_linear
        self.head_dim = head_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x``, RMS-normalise every ``head_dim`` chunk, restore the shape."""
        projected = self.base_linear(x)
        out_shape = projected.shape
        # Split the last dimension into (num_chunks, head_dim) so the norm
        # is applied to each chunk separately.
        per_head = projected.view(*out_shape[:-1], -1, self.head_dim)
        per_head = F.rms_norm(per_head, (self.head_dim,))
        return per_head.view(*out_shape)


def _partial_rope_inv_freq(head_dim: int, base: float = 1024.0) -> torch.Tensor:
    """Build the inverse-frequency table installed on the rotary embedding.

    The first ``head_dim // 4`` entries run geometrically from 1 down to
    ``1 / base``; the remaining entries are zero, which presumably leaves the
    matching channel pairs unrotated (confirm against the rotary module).

    The result always has ``head_dim // 2`` entries. Padding with a second
    ``head_dim // 4`` block of zeros would come up one short whenever
    ``head_dim % 4 == 2``.
    """
    num_rotated = head_dim // 4
    num_zero = head_dim // 2 - num_rotated
    decaying = (1 / base) ** torch.linspace(
        0, 1, steps=num_rotated, dtype=torch.float32
    )
    return torch.cat([decaying, decaying.new_zeros(num_zero)])


# ---------------------------------------------------------
# The Custom Auto-Patching Model
# ---------------------------------------------------------
class PatchedQwen3ForCausalLM(Qwen3ForCausalLM):
    """``Qwen3ForCausalLM`` with structural patches applied at construction.

    After the stock model is built, the constructor:

    * wraps ``lm_head`` in :class:`CappedLMHead`,
    * replaces the rotary embedding's ``inv_freq`` table,
    * wraps every layer's ``q_proj`` and ``k_proj`` in :class:`RMSNormLinear`.

    NOTE(review): the wrappers add a ``base_head`` / ``base_linear`` component
    to the affected parameter names, so checkpoints saved from an unpatched
    Qwen3 model will not line up key-for-key -- confirm the loading path.
    """

    config_class = PatchedQwen3Config

    def __init__(self, config):
        # 1. Initialize the standard architecture
        super().__init__(config)

        # 2. Apply structural monkey-patches dynamically
        self.lm_head = CappedLMHead(self.lm_head)

        head_dim = config.head_dim

        # Patch RoPE
        self.model.rotary_emb.inv_freq = _partial_rope_inv_freq(head_dim)

        # Patch Q/K Projections
        for layer in self.model.layers:
            attn = layer.self_attn
            attn.q_proj = RMSNormLinear(attn.q_proj, head_dim)
            attn.k_proj = RMSNormLinear(attn.k_proj, head_dim)

        # Ensure weights are initialized if trained from scratch,
        # but HF from_pretrained will overwrite these with your saved
        # state_dict safely.
        self.post_init()