# src/wjad/modules/ffn.py
"""SwiGLU 前馈网络。
实现:D -> Linear(2 * 4D) -> chunk2 -> SiLU(a) * b -> Linear(D)
即 D -> 4D -> SwiGLU -> 2D -> D,与 Design.md 规定一致。
"""
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """SwiGLU FFN: D -> 8D -> SwiGLU -> 4D -> D.

    Uses ``F.silu(a) * b``, consistent with the implementation in the
    existing ``swiglu.py``.
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
        super().__init__()
        hidden = mult * dim
        self.fc1 = nn.Linear(dim, hidden * 2, bias=bias)  # project out a and b in a single matmul
        self.fc2 = nn.Linear(hidden, dim, bias=bias)
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ab = self.fc1(x)
        a, b = ab.chunk(2, dim=-1)  # split fused projection into gate a and value b
        return self.drop(self.fc2(F.silu(a) * b))
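

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file: with dim=512 and
    # mult=4, fc1 projects 512 -> 4096, chunk yields two 2048-wide halves, and
    # fc2 projects 2048 -> 512, so the output shape matches the input.
    ffn = SwiGLUFFN(dim=512, mult=4, dropout=0.1)
    x = torch.randn(2, 16, 512)  # (batch, seq_len, D)
    y = ffn(x)
    assert y.shape == x.shape
    print(y.shape)  # torch.Size([2, 16, 512])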