import torch import torch.nn as nn import torch.nn.functional as F def swiglu(x: torch.Tensor) -> torch.Tensor: """ SwiGLU激活函数实现 SwiGLU是GLU(Gated Linear Unit)的一个变体,使用SiLU(也称为Swish)作为激活函数。 公式: SwiGLU(x) = SiLU(a) ⊗ b,其中a和b是x沿最后一个维度分成的两部分 参数: x: 输入张量,最后一个维度会被分成两半 返回: 激活后的张量,维度是输入的一半 """ a, b = x.chunk(2, dim=-1) # 将输入沿最后一个维度切分成两部分 return F.silu(a) * b # 对a应用SiLU激活函数,然后与b逐元素相乘 def swiglu_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ SwiGLU激活函数的成对版本 直接接受已经分开的两个张量,而不是一个需要切分的张量 参数: a: 第一个输入张量,将应用SiLU激活 b: 第二个输入张量,用于门控 返回: 激活后的张量 """ return F.silu(a) * b # SiLU(a) * b class SwiGLU(nn.Module): """ SwiGLU激活模块 将SwiGLU激活函数封装为一个PyTorch模块,方便在神经网络中使用 """ def forward(self, x: torch.Tensor) -> torch.Tensor: """ 前向传播 参数: x: 输入张量 返回: 激活后的张量 """ a, b = x.chunk(2, dim=-1) # 沿最后一个维度分成两半 return F.silu(a) * b # 应用SwiGLU激活