"""Register a squared-ReLU activation with Transformers and expose a Qwen3 LM class.

Importing this module patches ``transformers.activations`` so that a config
with ``hidden_act="squared_relu"`` can be resolved through ``ACT2FN``.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN

try:
    from transformers.activations import ACT2CLS
except ImportError:
    # ACT2CLS may be absent in some Transformers releases; in that case only
    # ACT2FN is patched below.
    ACT2CLS = None

from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3ForCausalLM as _Qwen3ForCausalLM,
)


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """Return ``relu(x) ** 2`` elementwise."""
    return torch.pow(F.relu(x), 2)


class SquaredReLUActivation(nn.Module):
    """``nn.Module`` wrapper around :func:`squared_relu` (no parameters)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply :func:`squared_relu` to ``x``."""
        return squared_relu(x)


def patch_transformers_squared_relu():
    """Register ``"squared_relu"`` in Transformers' activation registries.

    Lets Qwen3 MLP loading resolve ``hidden_act="squared_relu"`` through
    ``ACT2FN``. Works with both newer ClassInstantier-style registries and
    older plain-callable registries. The value written into ``ACT2FN`` mirrors
    the shape of the existing ``"silu"`` entry:

    * a ``(cls, kwargs)`` tuple      -> ``(SquaredReLUActivation, {})``
    * an ``nn.Module`` subclass      -> ``SquaredReLUActivation``
    * anything else (a callable)     -> the :func:`squared_relu` function

    If there is no ``"silu"`` entry to mirror but ``ACT2CLS`` exists, the
    module class is registered. Registering the bare function there would make
    a class-instantiating registry call ``squared_relu()`` with no arguments.

    When ``ACT2CLS`` is available, ``SquaredReLUActivation`` is added to it as
    well. Calling this more than once is harmless: it overwrites the same keys.

    Returns:
        The :func:`squared_relu` function.
    """
    # Use ``.get`` rather than ``[]``: on a dict subclass, ``dict.get`` reads
    # the stored value without going through an overridden ``__getitem__``,
    # so we inspect what is registered rather than what a lookup produces.
    raw_silu = ACT2FN.get("silu", None)

    if ACT2CLS is not None:
        ACT2CLS["squared_relu"] = SquaredReLUActivation

    if isinstance(raw_silu, tuple):
        ACT2FN["squared_relu"] = (SquaredReLUActivation, {})
    elif isinstance(raw_silu, type) and issubclass(raw_silu, nn.Module):
        ACT2FN["squared_relu"] = SquaredReLUActivation
    elif raw_silu is None and ACT2CLS is not None:
        # No reference entry to mirror. ACT2CLS existing suggests a
        # class-based registry. NOTE(review): heuristic; confirm against the
        # installed Transformers version.
        ACT2FN["squared_relu"] = SquaredReLUActivation
    else:
        ACT2FN["squared_relu"] = squared_relu
    return squared_relu


# Patch at import time so the activation is registered before any model using
# hidden_act="squared_relu" is constructed.
patch_transformers_squared_relu()


class SquaredReLUQwen3ForCausalLM(_Qwen3ForCausalLM):
    """``Qwen3ForCausalLM`` subclass that adds no behavior of its own.

    Squared-ReLU support comes from the registry patch applied when this
    module is imported.
    """