Ba2han commited on
Commit
7e3324f
·
verified ·
1 Parent(s): 897b6eb

Upload patch.py

Browse files
Files changed (1) hide show
  1. patch.py +51 -0
patch.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from transformers.activations import ACT2FN
7
+
8
+ try:
9
+ from transformers.activations import ACT2CLS
10
+ except Exception:
11
+ ACT2CLS = None
12
+
13
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM as _Qwen3ForCausalLM
14
+
15
+
16
+ def squared_relu(x: torch.Tensor) -> torch.Tensor:
17
+ return torch.pow(F.relu(x), 2)
18
+
19
+
20
+ class SquaredReLUActivation(nn.Module):
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ return squared_relu(x)
23
+
24
+
25
+ def patch_transformers_squared_relu():
26
+ """
27
+ Register squared_relu for Qwen3 MLP loading.
28
+
29
+ Works with both newer Transformers ACT2FN ClassInstantier-style registries
30
+ and older plain callable registries.
31
+ """
32
+ raw_silu = ACT2FN.get("silu", None)
33
+
34
+ if ACT2CLS is not None:
35
+ ACT2CLS["squared_relu"] = SquaredReLUActivation
36
+
37
+ if isinstance(raw_silu, tuple):
38
+ ACT2FN["squared_relu"] = (SquaredReLUActivation, {})
39
+ elif isinstance(raw_silu, type) and issubclass(raw_silu, nn.Module):
40
+ ACT2FN["squared_relu"] = SquaredReLUActivation
41
+ else:
42
+ ACT2FN["squared_relu"] = squared_relu
43
+
44
+ return squared_relu
45
+
46
+
47
+ patch_transformers_squared_relu()
48
+
49
+
50
+ class SquaredReLUQwen3ForCausalLM(_Qwen3ForCausalLM):
51
+ pass