File size: 873 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
import torch.nn.functional as F


def intermediate_correction_fn(expansion_ratio: float, hidden_size: int) -> int:
    """Return the FFN intermediate width as a multiple of 256.

    The raw width ``expansion_ratio * hidden_size`` is offset by 255 and
    floored to a multiple of 256, so an integer-valued raw width is rounded
    up to the next multiple of 256 (exact multiples are left unchanged).

    Args:
        expansion_ratio: Multiplier applied to ``hidden_size``.
        hidden_size: Base width being expanded.

    Returns:
        The aligned width as an ``int``.
    """
    raw_width = expansion_ratio * hidden_size
    # Adding 255 before the floor division turns "floor" into "round up"
    # for integer-valued widths.
    block_count = (raw_width + 255) // 256
    return int(block_count * 256)


class SwiGLU(nn.Module):
    """Parameter-free gated activation: ``silu(a) * b`` over the two halves of the last dim."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` into two chunks along the last dimension and gate one with the other.

        Args:
            x: Input tensor; its last dimension is chunked in two.

        Returns:
            ``F.silu(first_chunk) * second_chunk``.
        """
        gate_half, value_half = torch.chunk(x, 2, dim=-1)
        activated = F.silu(gate_half)
        return activated * value_half


def swiglu_ln_ffn(hidden_size: int, expansion_ratio: float, dropout: float = 0.1, use_bias: bool = False):
    """Build a LayerNorm -> Linear -> SwiGLU -> Dropout -> Linear feed-forward stack.

    Args:
        hidden_size: Input and output feature width of the stack.
        expansion_ratio: Passed with ``hidden_size`` to
            ``intermediate_correction_fn`` to obtain the inner width.
        dropout: Probability given to ``nn.Dropout``.
        use_bias: Whether the two ``nn.Linear`` layers carry a bias term.

    Returns:
        An ``nn.Sequential`` holding the five layers in the order above.
    """
    inner_width = intermediate_correction_fn(expansion_ratio, hidden_size)

    # Modules are created in the same order they appear in the stack.
    norm = nn.LayerNorm(hidden_size)
    # The first projection is twice as wide because SwiGLU halves the last dim.
    up_proj = nn.Linear(hidden_size, inner_width * 2, bias=use_bias)
    gate = SwiGLU()
    drop = nn.Dropout(dropout)
    down_proj = nn.Linear(inner_width, hidden_size, bias=use_bias)

    return nn.Sequential(norm, up_proj, gate, drop, down_proj)