| """SwiGLU 前馈网络。 |
| |
| 实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D) |
| 即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class SwiGLUFFN(nn.Module):
    """Feed-forward network with a SwiGLU gate.

    A single fused linear layer projects the input to twice the hidden
    width; the result is split into a gate half and a value half, which
    are combined as ``F.silu(gate) * value`` and projected back to
    ``dim``. With the default ``mult=4`` the dimension flow is
    ``D -> 8D -> (gated) 4D -> D``.

    Args:
        dim: Input and output feature dimension.
        mult: Hidden-width multiplier; the gated hidden size is ``mult * dim``.
        dropout: Dropout probability applied to the output; ``0`` disables it.
        bias: Whether both linear layers carry bias terms.
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
        super().__init__()
        inner = dim * mult
        # One fused projection yields both the gate and the value halves.
        self.fc1 = nn.Linear(dim, 2 * inner, bias=bias)
        self.fc2 = nn.Linear(inner, dim, bias=bias)
        # Identity keeps the forward path allocation-free when dropout is off.
        self.drop = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.fc1(x).chunk(2, dim=-1)
        hidden = F.silu(gate) * value
        return self.drop(self.fc2(hidden))
|
|