| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def swiglu(x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| SwiGLU激活函数实现 | |
| SwiGLU是GLU(Gated Linear Unit)的一个变体,使用SiLU(也称为Swish)作为激活函数。 | |
| 公式: SwiGLU(x) = SiLU(a) ⊗ b,其中a和b是x沿最后一个维度分成的两部分 | |
| 参数: | |
| x: 输入张量,最后一个维度会被分成两半 | |
| 返回: | |
| 激活后的张量,维度是输入的一半 | |
| """ | |
| a, b = x.chunk(2, dim=-1) # 将输入沿最后一个维度切分成两部分 | |
| return F.silu(a) * b # 对a应用SiLU激活函数,然后与b逐元素相乘 | |
| def swiglu_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: | |
| """ | |
| SwiGLU激活函数的成对版本 | |
| 直接接受已经分开的两个张量,而不是一个需要切分的张量 | |
| 参数: | |
| a: 第一个输入张量,将应用SiLU激活 | |
| b: 第二个输入张量,用于门控 | |
| 返回: | |
| 激活后的张量 | |
| """ | |
| return F.silu(a) * b # SiLU(a) * b | |
| class SwiGLU(nn.Module): | |
| """ | |
| SwiGLU激活模块 | |
| 将SwiGLU激活函数封装为一个PyTorch模块,方便在神经网络中使用 | |
| """ | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| 前向传播 | |
| 参数: | |
| x: 输入张量 | |
| 返回: | |
| 激活后的张量 | |
| """ | |
| a, b = x.chunk(2, dim=-1) # 沿最后一个维度分成两半 | |
| return F.silu(a) * b # 应用SwiGLU激活 |