WJAD / swiglu.py
fuzirui's picture
Upload folder using huggingface_hub
2f30c49 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
def swiglu(x: torch.Tensor) -> torch.Tensor:
"""
SwiGLU激活函数实现
SwiGLU是GLU(Gated Linear Unit)的一个变体,使用SiLU(也称为Swish)作为激活函数。
公式: SwiGLU(x) = SiLU(a) ⊗ b,其中a和b是x沿最后一个维度分成的两部分
参数:
x: 输入张量,最后一个维度会被分成两半
返回:
激活后的张量,维度是输入的一半
"""
a, b = x.chunk(2, dim=-1) # 将输入沿最后一个维度切分成两部分
return F.silu(a) * b # 对a应用SiLU激活函数,然后与b逐元素相乘
def swiglu_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
SwiGLU激活函数的成对版本
直接接受已经分开的两个张量,而不是一个需要切分的张量
参数:
a: 第一个输入张量,将应用SiLU激活
b: 第二个输入张量,用于门控
返回:
激活后的张量
"""
return F.silu(a) * b # SiLU(a) * b
class SwiGLU(nn.Module):
"""
SwiGLU激活模块
将SwiGLU激活函数封装为一个PyTorch模块,方便在神经网络中使用
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
前向传播
参数:
x: 输入张量
返回:
激活后的张量
"""
a, b = x.chunk(2, dim=-1) # 沿最后一个维度分成两半
return F.silu(a) * b # 应用SwiGLU激活