File size: 1,274 Bytes
7e3324f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.activations import ACT2FN

try:
    from transformers.activations import ACT2CLS
except Exception:
    ACT2CLS = None

from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM as _Qwen3ForCausalLM


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise square of ``relu(x)``.

    Args:
        x: Input tensor.

    Returns:
        A tensor holding ``relu(x) ** 2`` for every element of ``x``.
    """
    rectified = F.relu(x)
    return rectified.pow(2)


class SquaredReLUActivation(nn.Module):
    """``nn.Module`` wrapper around the module-level ``squared_relu`` function."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``squared_relu`` to ``x`` and return the result."""
        activated = squared_relu(x)
        return activated


def patch_transformers_squared_relu():
    """
    Register the ``"squared_relu"`` activation with Transformers.

    The registry's existing ``"silu"`` entry is fetched with ``ACT2FN.get``
    and its type decides which form is stored under ``"squared_relu"``:

    * a tuple               -> ``(SquaredReLUActivation, {})``
    * an ``nn.Module`` class -> ``SquaredReLUActivation``
    * anything else (including a missing entry) -> the plain
      ``squared_relu`` function

    This is meant to cover both the newer ClassInstantier-style ``ACT2FN``
    registries and the older plain-callable ones. When ``ACT2CLS`` could be
    imported, the module class is registered there as well.

    Returns:
        The ``squared_relu`` function.
    """
    silu_entry = ACT2FN.get("silu", None)

    if ACT2CLS is not None:
        ACT2CLS["squared_relu"] = SquaredReLUActivation

    # Mirror the shape of the existing "silu" entry so the new entry is
    # stored in the same form as the registry's own entries.
    stores_module_class = isinstance(silu_entry, type) and issubclass(silu_entry, nn.Module)
    if isinstance(silu_entry, tuple):
        registration = (SquaredReLUActivation, {})
    elif stores_module_class:
        registration = SquaredReLUActivation
    else:
        registration = squared_relu

    ACT2FN["squared_relu"] = registration
    return squared_relu


# Apply the patch at import time so that "squared_relu" is registered as soon
# as this module is loaded.
patch_transformers_squared_relu()


class SquaredReLUQwen3ForCausalLM(_Qwen3ForCausalLM):
    """Qwen3 causal LM subclass that adds no behavior of its own.

    The ``"squared_relu"`` registration happens at module import, not in this
    class. NOTE(review): presumably this exists so checkpoints whose config
    names that activation can be loaded through this class; confirm against
    the model config.
    """