"""
A collection of FFN blocks
"""
import torch
import torch.nn.functional as F
from models.components.layers.activations import build_activation


class GenericFFN(torch.nn.Module):
    """
    A simple two-layer feedforward network:
    Linear -> activation -> Linear.
    """

    def __init__(
        self,
        hidden_dim,
        ffn_dim,
        bias,
        ffn_activation,
    ):
        super().__init__()
        # build the ffn block: project up to ffn_dim, apply the
        # activation, then project back down to hidden_dim
        self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias)
        self.activation = build_activation(activation_name=ffn_activation)
        self.linear_2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=bias)

    def forward(self, x):
        """
        A simple forward pass through the FFN.
        """
        x = self.linear_1(x)
        x = self.activation(x)
        x = self.linear_2(x)
        return x
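
# Example usage (a minimal sketch; "gelu" is an assumed key in
# build_activation's registry, not confirmed by this file):
#   ffn = GenericFFN(hidden_dim=512, ffn_dim=2048, bias=False, ffn_activation="gelu")
#   y = ffn(torch.randn(2, 16, 512))  # output shape matches the input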


class SwiGLUFFN(torch.nn.Module):
    """
    SwiGLU feedforward network.

    Implementation based on:
    https://github.com/meta-llama/llama3/blob/main/llama/model.py
    originally from https://arxiv.org/abs/2002.05202

    N.B. does not support dropout.
    """

    def __init__(
        self,
        hidden_dim,
        ffn_dim,
        bias,
    ):
        super().__init__()
        # build the linear functions: linear_1 feeds the SiLU gate,
        # linear_3 is the value branch, and linear_2 projects back down
        self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias)
        self.linear_2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=bias)
        self.linear_3 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias)

    def forward(self, x):
        """
        A simple forward pass through the FFN:
        SwiGLU(x) = W2(SiLU(W1 x) * W3 x)
        """
        return self.linear_2(F.silu(self.linear_1(x)) * self.linear_3(x))
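
# Example usage (a minimal sketch):
#   ffn = SwiGLUFFN(hidden_dim=512, ffn_dim=2048, bias=False)
#   y = ffn(torch.randn(2, 16, 512))  # output shape matches the input
# Note: SwiGLU has three projections rather than two, so implementations
# often shrink ffn_dim to roughly 2/3 of a generic FFN's width to keep
# the parameter count comparable (as in the Llama models).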


FFN_DICT = {
    "generic": lambda hidden_dim, ffn_cfg: GenericFFN(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_cfg["ffn_dim"],
        bias=ffn_cfg["bias"],
        ffn_activation=ffn_cfg["activation"],
    ),
    "swiglu": lambda hidden_dim, ffn_cfg: SwiGLUFFN(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_cfg["ffn_dim"],
        bias=ffn_cfg["bias"],
    ),
}


def build_ffn(hidden_dim, ffn_cfg):
    """
    Build a feedforward network from a config dict.

    Expects ffn_cfg to contain "ffn_type" (a key of FFN_DICT) plus the
    keys the corresponding constructor reads, e.g. "ffn_dim" and "bias".
    """
    return FFN_DICT[ffn_cfg["ffn_type"]](hidden_dim=hidden_dim, ffn_cfg=ffn_cfg)
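

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library). The config
    # keys mirror exactly what FFN_DICT reads for the "swiglu" entry, so
    # no assumption about build_activation's registry is needed here.
    cfg = {
        "ffn_type": "swiglu",
        "ffn_dim": 2048,
        "bias": False,
    }
    ffn = build_ffn(hidden_dim=512, ffn_cfg=cfg)
    x = torch.randn(2, 16, 512)  # (batch, seq_len, hidden_dim)
    y = ffn(x)
    assert y.shape == x.shape  # FFN blocks preserve the hidden dimension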