# experimental_auto / patch.py
# Uploaded by Ba2han as "patch.py" (revision 7e3324f, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
try:
from transformers.activations import ACT2CLS
except Exception:
ACT2CLS = None
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM as _Qwen3ForCausalLM
def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """
    Return ``relu(x) ** 2`` computed element-wise.

    Negative entries become 0; non-negative entries are squared.
    """
    rectified = F.relu(x)
    return rectified.square()
class SquaredReLUActivation(nn.Module):
    """
    Parameter-free ``nn.Module`` wrapper around ``squared_relu``.

    This is the form registered in the activation registries that store
    classes rather than plain functions.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``squared_relu`` to ``x`` and return the result."""
        result = squared_relu(x)
        return result
def patch_transformers_squared_relu():
    """
    Register ``"squared_relu"`` in the Transformers activation registries.

    Qwen3's MLP resolves its activation by name through ``ACT2FN``, so the
    entry must exist before a model configured with
    ``hidden_act="squared_relu"`` is built.

    Two registry layouts are supported:

    * Newer Transformers: ``ACT2CLS`` is importable and ``ACT2FN`` is a
      ``ClassInstantier``-style mapping whose lookups instantiate the stored
      class (or ``(class, kwargs)`` tuple). ``ACT2FN`` is built from a copy of
      ``ACT2CLS``, so both are updated. The nn.Module class is registered.
    * Older Transformers: ``ACT2FN`` is a plain dict of callables and the bare
      ``squared_relu`` function is registered.

    Calling this more than once simply overwrites the same entries.

    Returns:
        The ``squared_relu`` function.
    """
    if ACT2CLS is not None:
        ACT2CLS["squared_relu"] = SquaredReLUActivation

    # ``.get`` on a dict subclass does not go through an overridden
    # ``__getitem__``, so this is the raw stored entry (a class or a tuple in
    # class-based registries), not an instantiated module.
    raw_silu = ACT2FN.get("silu", None)

    if isinstance(raw_silu, tuple):
        # Mirror the (class, kwargs) storage format used for "silu".
        ACT2FN["squared_relu"] = (SquaredReLUActivation, {})
    elif ACT2CLS is not None or (
        isinstance(raw_silu, type) and issubclass(raw_silu, nn.Module)
    ):
        # Class-instantiating registry. Storing the bare function here would
        # make the lookup call ``squared_relu()`` with no argument and fail,
        # so the module class is always registered. This holds even when the
        # "silu" probe entry is missing or unexpected.
        ACT2FN["squared_relu"] = SquaredReLUActivation
    else:
        # Plain registry of callables: lookups return the stored object as-is.
        ACT2FN["squared_relu"] = squared_relu
    return squared_relu
# Runs at import time: importing this module registers "squared_relu" as a
# side effect, before any model defined below can be constructed.
patch_transformers_squared_relu()
class SquaredReLUQwen3ForCausalLM(_Qwen3ForCausalLM):
    """
    Qwen3 causal-LM class exposed under its own name by this patch module.

    It adds no behaviour of its own and inherits everything from Transformers'
    ``Qwen3ForCausalLM``. The module that defines it registers the
    ``"squared_relu"`` activation when imported.
    """