File size: 955 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""SwiGLU 前馈网络。

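Implementation: D -> Linear(2 * mult * D) -> chunk into 2 -> SiLU(a) * b -> Linear(D).
With the default mult=4 this is D -> 8D -> SwiGLU -> 4D -> D.
NOTE(review): this line previously read "D -> 4D -> SwiGLU -> 2D -> D, as specified
in Design.md", which does not match the layer widths built below; confirm against Design.md.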
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """Gated feed-forward block using the SwiGLU activation.

    A single linear layer projects the last dimension from ``dim`` to
    ``2 * mult * dim`` features. The result is split into two equal halves
    (gate and value), combined as ``silu(gate) * value``, projected back to
    ``dim`` and passed through the dropout module.

    NOTE(review): the previous docstring described the flow as
    "D->4D->SwiGLU->2D->D". With the default ``mult=4`` the layers built
    here are D -> 8D -> (gated) 4D -> D; confirm which one is intended.
    The previous docstring also stated that ``F.silu(a) * b`` mirrors the
    existing ``swiglu.py``; that file is not visible from here.

    Args:
        dim: Size of the last dimension of both the input and the output.
        mult: Width multiplier; the gated hidden width is ``mult * dim``.
        dropout: Dropout probability applied to the output. When it is not
            greater than zero, ``nn.Identity`` is used instead of
            ``nn.Dropout``.
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0, bias: bool = True) -> None:
        super().__init__()
        inner_width = dim * mult
        # One fused projection yields the gate and value halves together.
        self.fc1 = nn.Linear(dim, 2 * inner_width, bias=bias)
        self.fc2 = nn.Linear(inner_width, dim, bias=bias)
        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform along the last dimension."""
        gate, value = torch.chunk(self.fc1(x), 2, dim=-1)
        gated = F.silu(gate) * value
        projected = self.fc2(gated)
        return self.drop(projected)