Ba2han committed on
Commit
fe99ebb
·
verified ·
1 Parent(s): 7f205a2

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_patched.py +10 -0
  2. modeling_patched.py +70 -0
configuration_patched.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class PatchedQwen3Config(PretrainedConfig):
    """Configuration for the ``patched_qwen3`` model type.

    All keyword arguments are forwarded to ``PretrainedConfig``; two of
    them receive a default when the caller does not supply them.
    """

    model_type = "patched_qwen3"

    def __init__(self, **kwargs):
        """Create the config, filling in defaults for absent keys.

        ``hidden_act`` defaults to ``"squared_relu"`` and ``rope_theta`` to
        ``1000000.0``. A value passed explicitly by the caller (or read from
        a saved config) is kept as-is; nothing is overridden.
        """
        # setdefault only writes when the key is missing, so explicit
        # caller-provided values always win.
        kwargs.setdefault("hidden_act", "squared_relu")
        kwargs.setdefault("rope_theta", 1000000.0)
        super().__init__(**kwargs)
modeling_patched.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel, AutoModelForCausalLM
5
+ from transformers.activations import ACT2FN
6
+ from .configuration_patched import PatchedQwen3Config
7
+
8
+ from transformers import Qwen3ForCausalLM
9
+
10
+
11
+ # ---------------------------------------------------------
12
+ # Define the Custom Components
13
+ # ---------------------------------------------------------
14
class SquaredReLU(nn.Module):
    """Element-wise activation that squares the rectified input."""

    def forward(self, x):
        """Return ``relu(x) ** 2`` with the same shape as ``x``."""
        rectified = F.relu(x)
        return torch.square(rectified)
17
+
18
# Register the activation in transformers' ACT2FN table under the key
# "squared_relu" so the base architecture can look it up by that name
# (the key matches the default `hidden_act` set by PatchedQwen3Config).
# The class itself, not an instance, is stored.
# NOTE(review): presumably ACT2FN instantiates stored classes on lookup —
# confirm against the installed transformers version.
ACT2FN["squared_relu"] = SquaredReLU
20
+
21
class CappedLMHead(torch.nn.Module):
    """Wrap an output head and softly bound its logits to ``(-cap, cap)``.

    The wrapped head's output ``z`` is mapped to
    ``z / sqrt(1 + (z / cap) ** 2)``. This is close to the identity near
    zero and saturates smoothly at ``+/-cap``. With the default ``cap=15``
    it is mathematically the same function as
    ``15 * z * rsqrt(z ** 2 + 225)``.

    Args:
        base_head: Module producing the raw logits. It is stored as the
            ``base_head`` submodule, so its parameters appear under that
            prefix in the state dict.
        cap: Positive saturation value for the logits. Defaults to 15.0.

    Raises:
        ValueError: If ``cap`` is not strictly positive.
    """

    def __init__(self, base_head, cap=15.0):
        super().__init__()
        # `not cap > 0` also rejects NaN.
        if not cap > 0:
            raise ValueError(f"cap must be positive, got {cap!r}")
        self.base_head = base_head
        # Plain float attribute: not a parameter or buffer, so the state
        # dict layout is unaffected.
        self.cap = float(cap)

    def forward(self, x):
        """Return the soft-capped logits of ``base_head(x)``."""
        logits = self.base_head(x)
        # Square the cap-scaled value rather than the raw logits: the raw
        # square overflows to inf much earlier (notably in half precision),
        # which makes rsqrt return 0 and collapses the output to 0 instead
        # of saturating at +/-cap.
        scaled = logits / self.cap
        return logits * torch.rsqrt(scaled.square() + 1)
29
+
30
+ class RMSNormLinear(torch.nn.Module):
31
+ def __init__(self, base_linear, head_dim):
32
+ super().__init__()
33
+ self.base_linear = base_linear
34
+ self.head_dim = head_dim
35
+
36
+ def forward(self, x):
37
+ y = self.base_linear(x)
38
+ shape = y.shape
39
+ y = y.view(*shape[:-1], -1, self.head_dim)
40
+ y = F.rms_norm(y, (self.head_dim,))
41
+ return y.view(*shape)
42
+
43
+ # ---------------------------------------------------------
44
+ # The Custom Auto-Patching Model
45
+ # ---------------------------------------------------------
46
class PatchedQwen3ForCausalLM(Qwen3ForCausalLM):
    """Qwen3 causal LM that patches itself right after construction.

    On top of the stock architecture built by ``Qwen3ForCausalLM`` it:

    * wraps ``lm_head`` in ``CappedLMHead``;
    * replaces the rotary embedding's ``inv_freq`` with a custom table;
    * wraps every layer's ``q_proj`` / ``k_proj`` in ``RMSNormLinear``.
    """

    config_class = PatchedQwen3Config

    def __init__(self, config):
        """Build the base model, then apply the structural patches.

        Args:
            config: Model configuration. ``head_dim`` is used when present;
                otherwise it is derived from ``hidden_size`` and
                ``num_attention_heads``.

        Raises:
            ValueError: If the resolved head dimension is not a positive
                multiple of 4 (required by the custom RoPE table).
        """
        # 1. Initialize the standard architecture.
        super().__init__(config)

        # 2. Apply the structural patches.
        # NOTE(review): after wrapping, the head's weight lives at
        # `lm_head.base_head.weight`. Confirm that weight tying
        # (`tie_word_embeddings`) and `get_output_embeddings()` behave as
        # intended with the installed transformers version.
        self.lm_head = CappedLMHead(self.lm_head)

        head_dim = self._resolve_head_dim(config)

        # Patch RoPE. The table is built from constants only, so the
        # config's `rope_theta` does not enter this computation.
        self.model.rotary_emb.inv_freq = self._build_inv_freq(head_dim)

        # Patch Q/K projections.
        for layer in self.model.layers:
            attention = layer.self_attn
            attention.q_proj = RMSNormLinear(attention.q_proj, head_dim)
            attention.k_proj = RMSNormLinear(attention.k_proj, head_dim)

        # Run the post-init hook again now that the module tree has changed.
        self.post_init()

    @staticmethod
    def _resolve_head_dim(config):
        """Return the per-head width, validating it for the RoPE patch."""
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            # PatchedQwen3Config derives from plain PretrainedConfig, so
            # `head_dim` exists only if it was supplied explicitly.
            head_dim = config.hidden_size // config.num_attention_heads
        if head_dim <= 0 or head_dim % 4 != 0:
            # Two blocks of head_dim // 4 entries must total head_dim // 2.
            raise ValueError(
                f"head_dim must be a positive multiple of 4, got {head_dim}"
            )
        return head_dim

    @staticmethod
    def _build_inv_freq(head_dim):
        """Return the custom inverse-frequency table of length ``head_dim // 2``.

        The first ``head_dim // 4`` entries fall geometrically from 1 to
        1/1024; the remaining ``head_dim // 4`` entries are zero.
        """
        quarter = head_dim // 4
        rotating = (1 / 1024) ** torch.linspace(0, 1, steps=quarter, dtype=torch.float32)
        return torch.cat([rotating, rotating.new_zeros(quarter)])