"""SwiGLU 前馈网络。 实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D) 即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。 """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F class SwiGLUFFN(nn.Module): """SwiGLU FFN: D->4D->SwiGLU->2D->D。 使用 ``F.silu(a) * b`` 与现有 ``swiglu.py`` 中的实现一致。 """ def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None: super().__init__() hidden = mult * dim self.fc1 = nn.Linear(dim, hidden * 2, bias=bias) # 一次性投影出 a,b self.fc2 = nn.Linear(hidden, dim, bias=bias) self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: ab = self.fc1(x) a, b = ab.chunk(2, dim=-1) return self.drop(self.fc2(F.silu(a) * b))