# Source: test-pt-2 / modeling_patched.py
# Uploaded by Ba2han ("Upload 2 files", commit fe99ebb, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from .configuration_patched import PatchedQwen3Config
from transformers import Qwen3ForCausalLM
# ---------------------------------------------------------
# Define the Custom Components
# ---------------------------------------------------------
class SquaredReLU(nn.Module):
    """Element-wise activation that rectifies its input and then squares it."""

    def forward(self, x):
        """Return ``relu(x) ** 2`` for every element of ``x``."""
        rectified = torch.relu(x)
        return torch.square(rectified)
# Register the activation under the key "squared_relu" so that code which
# looks activations up by name in ACT2FN can find it.
# NOTE(review): the class itself (not an instance) is stored here; this
# presumably relies on ACT2FN instantiating entries on lookup — confirm
# against the installed transformers version.
ACT2FN["squared_relu"] = SquaredReLU
class CappedLMHead(torch.nn.Module):
    """Wrap an LM head and smoothly squash its logits into ``(-cap, cap)``.

    With ``z`` the wrapped head's output, the result is
    ``cap * z / sqrt(z**2 + cap**2)``: close to ``z`` for small ``|z|`` and
    approaching ``+/-cap`` as ``|z|`` grows.

    Args:
        base_head: Module that produces the raw logits. It is kept as the
            ``base_head`` sub-module.
        cap: Positive saturation value for the logits. Defaults to 15, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``cap`` is not a positive number.
    """

    def __init__(self, base_head, cap=15.0):
        super().__init__()
        cap = float(cap)
        if not cap > 0:  # also rejects NaN
            raise ValueError(f"cap must be a positive number, got {cap!r}")
        self.base_head = base_head
        self.cap = cap

    def forward(self, x):
        """Return the soft-capped logits for input ``x``."""
        logits = self.base_head(x)
        # Algebraically the same as cap * z * rsqrt(z**2 + cap**2), but the
        # value being squared is z / cap rather than z. Squaring z directly
        # overflows to inf for large |z| (easily reached in float16), and
        # rsqrt(inf) == 0 then turns a huge logit into 0 instead of +/-cap.
        return logits * torch.rsqrt((logits / self.cap).square() + 1.0)
class RMSNormLinear(torch.nn.Module):
def __init__(self, base_linear, head_dim):
super().__init__()
self.base_linear = base_linear
self.head_dim = head_dim
def forward(self, x):
y = self.base_linear(x)
shape = y.shape
y = y.view(*shape[:-1], -1, self.head_dim)
y = F.rms_norm(y, (self.head_dim,))
return y.view(*shape)
# ---------------------------------------------------------
# The Custom Auto-Patching Model
# ---------------------------------------------------------
class PatchedQwen3ForCausalLM(Qwen3ForCausalLM):
    """Qwen3 causal LM that patches its own architecture at construction time.

    After the stock model has been built, ``__init__``:

    * wraps ``lm_head`` in :class:`CappedLMHead`;
    * overwrites the rotary embedding's ``inv_freq`` with ``head_dim // 4``
      geometrically spaced frequencies running from 1 down to 1/1024,
      followed by ``head_dim // 4`` zeros;
    * wraps every layer's ``q_proj`` and ``k_proj`` in
      :class:`RMSNormLinear`.

    The wrappers keep the original linear layers as sub-modules, so the
    affected parameter names gain a ``base_head`` / ``base_linear`` segment.
    """

    config_class = PatchedQwen3Config

    def __init__(self, config):
        """Build the stock model from ``config`` and apply the patches.

        Raises:
            ValueError: If the head dimension is not a positive multiple
                of 4 (the patched ``inv_freq`` could not have
                ``head_dim // 2`` entries).
        """
        # 1. Initialize the standard architecture
        super().__init__(config)

        # Use the explicit head_dim when the config carries one, otherwise
        # derive it from the hidden size and the number of attention heads.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        inv_freq = self._partial_rope_inv_freq(head_dim)

        # 2. Apply structural monkey-patches dynamically
        self.lm_head = CappedLMHead(self.lm_head)

        # Patch RoPE
        rotary = self.model.rotary_emb
        rotary.inv_freq = inv_freq
        # NOTE(review): some transformers versions keep a second copy of the
        # frequencies as ``original_inv_freq`` and restore from it; keep it
        # in sync when present so such a reset does not undo the patch.
        if hasattr(rotary, "original_inv_freq"):
            rotary.original_inv_freq = inv_freq

        # Patch Q/K Projections
        for layer in self.model.layers:
            attn = layer.self_attn
            attn.q_proj = RMSNormLinear(attn.q_proj, head_dim)
            attn.k_proj = RMSNormLinear(attn.k_proj, head_dim)

        # Ensure weights are initialized if trained from scratch,
        # but HF from_pretrained will overwrite these with your saved state_dict safely.
        self.post_init()

    @staticmethod
    def _partial_rope_inv_freq(head_dim):
        """Return the patched RoPE inverse-frequency vector for ``head_dim``.

        The first ``head_dim // 4`` entries are ``(1 / 1024) ** t`` for ``t``
        evenly spaced over ``[0, 1]``; the remaining ``head_dim // 4``
        entries are zero (presumably leaving those channel pairs unrotated —
        confirm against the rotary implementation in use).
        """
        if head_dim <= 0 or head_dim % 4 != 0:
            raise ValueError(
                f"head_dim must be a positive multiple of 4, got {head_dim!r}"
            )
        quarter = head_dim // 4
        active = (1 / 1024) ** torch.linspace(
            0, 1, steps=quarter, dtype=torch.float32
        )
        return torch.cat([active, active.new_zeros(quarter)])

    def get_output_embeddings(self):
        """Return the projection layer that actually owns the output weight.

        Once ``lm_head`` is wrapped, the wrapper itself has no ``weight``;
        callers that tie or resize the output embeddings need the inner
        layer. Before wrapping (during ``super().__init__``) ``lm_head`` is
        returned unchanged.
        """
        head = self.lm_head
        if isinstance(head, CappedLMHead):
            return head.base_head
        return head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection while keeping the logit cap wrapper."""
        if isinstance(self.lm_head, CappedLMHead):
            self.lm_head.base_head = new_embeddings
        else:
            self.lm_head = new_embeddings