stegsoph's picture
Upload folder using huggingface_hub
5d2c747 verified
"""
A collection of FFN blocks
"""
import torch
import torch.nn.functional as F
from models.components.layers.activations import build_activation
class GenericFFN(torch.nn.Module):
    """
    Two-layer feedforward block: Linear -> activation -> Linear.

    Args:
        hidden_dim: input and output feature size.
        ffn_dim: feature size between the two linear layers.
        bias: whether both linear layers carry a bias term.
        ffn_activation: name handed to ``build_activation`` to create the
            nonlinearity applied between the two linear layers.
    """

    def __init__(self, hidden_dim, ffn_dim, bias, ffn_activation):
        super().__init__()
        # Submodules are registered in this exact order (linear_1, activation,
        # linear_2) so parameter ordering and seeded initialisation stay stable.
        self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias)  # hidden_dim -> ffn_dim
        self.activation = build_activation(activation_name=ffn_activation)
        self.linear_2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=bias)  # ffn_dim -> hidden_dim

    def forward(self, x):
        """
        Apply linear_1, then the activation, then linear_2 to ``x``.
        """
        return self.linear_2(self.activation(self.linear_1(x)))
class SwiGLUFFN(torch.nn.Module):
    """
    SwiGLU feedforward block computing ``linear_2(silu(linear_1(x)) * linear_3(x))``.

    Implementation based on:
    https://github.com/meta-llama/llama3/blob/main/llama/model.py
    originally from https://arxiv.org/abs/2002.05202
    N.B. does not support dropout

    Args:
        hidden_dim: input and output feature size.
        ffn_dim: feature size of the gated intermediate representation.
        bias: whether all three linear layers carry a bias term.
    """

    def __init__(self, hidden_dim, ffn_dim, bias):
        super().__init__()
        # Creation order 1, 2, 3 is preserved so parameter ordering and
        # seeded initialisation are unchanged.
        self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias)  # gate branch (goes through SiLU)
        self.linear_2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=bias)  # output projection
        self.linear_3 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias)  # linear branch (multiplied by gate)

    def forward(self, x):
        """
        Gate ``linear_3(x)`` elementwise with ``silu(linear_1(x))``, then project
        back to ``hidden_dim`` with ``linear_2``.
        """
        gate = F.silu(self.linear_1(x))
        linear_branch = self.linear_3(x)
        return self.linear_2(gate * linear_branch)
def _build_generic_ffn(hidden_dim, ffn_cfg):
    """Build a GenericFFN from the "ffn_dim", "bias" and "activation" config keys."""
    return GenericFFN(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_cfg["ffn_dim"],
        bias=ffn_cfg["bias"],
        ffn_activation=ffn_cfg["activation"],
    )


def _build_swiglu_ffn(hidden_dim, ffn_cfg):
    """Build a SwiGLUFFN from the "ffn_dim" and "bias" config keys."""
    return SwiGLUFFN(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_cfg["ffn_dim"],
        bias=ffn_cfg["bias"],
    )


# Registry mapping an ``ffn_cfg["ffn_type"]`` string to a builder called as
# ``builder(hidden_dim=..., ffn_cfg=...)``. Named functions (rather than
# lambdas) give readable tracebacks and are picklable.
FFN_DICT = {
    "generic": _build_generic_ffn,
    "swiglu": _build_swiglu_ffn,
}


def build_ffn(hidden_dim, ffn_cfg):
    """
    Build a feedforward network.

    Args:
        hidden_dim: input/output feature size of the block.
        ffn_cfg: mapping with an "ffn_type" key selecting an entry of
            ``FFN_DICT``, plus the keys that builder reads.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        KeyError: if "ffn_type" (or a key the chosen builder needs) is missing,
            or if "ffn_type" names an unregistered FFN type.
    """
    ffn_type = ffn_cfg["ffn_type"]
    try:
        builder = FFN_DICT[ffn_type]
    except KeyError:
        # Keep KeyError for backward compatibility, but name the valid choices.
        raise KeyError(
            f"Unknown ffn_type {ffn_type!r}; available types: {sorted(FFN_DICT)}"
        ) from None
    return builder(hidden_dim=hidden_dim, ffn_cfg=ffn_cfg)