|
|
""" |
|
|
A collection of FFN blocks |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from models.components.layers.activations import build_activation |
|
|
|
|
|
|
|
|
class GenericFFN(torch.nn.Module):
    """
    Two-layer feedforward block: linear -> activation -> linear.

    The first projection maps ``hidden_dim`` features to ``ffn_dim``,
    the activation is applied to that result, and the second projection
    maps back to ``hidden_dim``.
    """

    def __init__(self, hidden_dim, ffn_dim, bias, ffn_activation):
        """
        Args:
            hidden_dim: feature size of the block's input and output.
            ffn_dim: feature size of the intermediate projection.
            bias: forwarded as ``bias=`` to both linear layers.
            ffn_activation: forwarded to ``build_activation`` as
                ``activation_name``.
        """
        super().__init__()

        # Layers are created in the order they are applied in forward().
        self.linear_1 = torch.nn.Linear(
            in_features=hidden_dim, out_features=ffn_dim, bias=bias
        )
        self.activation = build_activation(activation_name=ffn_activation)
        self.linear_2 = torch.nn.Linear(
            in_features=ffn_dim, out_features=hidden_dim, bias=bias
        )

    def forward(self, x):
        """
        Project up, apply the activation, then project back down.
        """
        return self.linear_2(self.activation(self.linear_1(x)))
|
|
|
|
|
|
|
|
class SwiGLUFFN(torch.nn.Module):
    """
    SwiGLU feedforward block.

    Follows the implementation in
    https://github.com/meta-llama/llama3/blob/main/llama/model.py
    which originates from https://arxiv.org/abs/2002.05202

    Computes ``linear_2(silu(linear_1(x)) * linear_3(x))``.

    N.B. does not support dropout
    """

    def __init__(self, hidden_dim, ffn_dim, bias):
        """
        Args:
            hidden_dim: feature size of the block's input and output.
            ffn_dim: feature size of the two intermediate projections.
            bias: forwarded as ``bias=`` to all three linear layers.
        """
        super().__init__()

        # linear_1 and linear_3 both map hidden_dim -> ffn_dim;
        # linear_2 maps the gated product back to hidden_dim.
        self.linear_1 = torch.nn.Linear(
            in_features=hidden_dim, out_features=ffn_dim, bias=bias
        )
        self.linear_2 = torch.nn.Linear(
            in_features=ffn_dim, out_features=hidden_dim, bias=bias
        )
        self.linear_3 = torch.nn.Linear(
            in_features=hidden_dim, out_features=ffn_dim, bias=bias
        )

    def forward(self, x):
        """
        Gate one projection of x with the SiLU of another, then project back.
        """
        gate = F.silu(self.linear_1(x))
        value = self.linear_3(x)
        return self.linear_2(gate * value)
|
|
|
|
|
|
|
|
def _build_generic_ffn(hidden_dim, ffn_cfg):
    """Construct a GenericFFN from the "ffn_dim", "bias" and "activation" config entries."""
    return GenericFFN(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_cfg["ffn_dim"],
        bias=ffn_cfg["bias"],
        ffn_activation=ffn_cfg["activation"],
    )


def _build_swiglu_ffn(hidden_dim, ffn_cfg):
    """Construct a SwiGLUFFN from the "ffn_dim" and "bias" config entries."""
    return SwiGLUFFN(
        hidden_dim=hidden_dim,
        ffn_dim=ffn_cfg["ffn_dim"],
        bias=ffn_cfg["bias"],
    )


# Registry mapping an FFN type name to a builder taking (hidden_dim, ffn_cfg).
FFN_DICT = {
    "generic": _build_generic_ffn,
    "swiglu": _build_swiglu_ffn,
}
|
|
|
|
|
|
|
|
def build_ffn(hidden_dim, ffn_cfg):
    """
    Build the feedforward network selected by ``ffn_cfg["ffn_type"]``.

    The type name is looked up in ``FFN_DICT`` and the matching builder
    is called with ``hidden_dim`` and the full ``ffn_cfg``. A type name
    that is not registered raises ``KeyError`` from the lookup.
    """
    ffn_builder = FFN_DICT[ffn_cfg["ffn_type"]]
    return ffn_builder(hidden_dim=hidden_dim, ffn_cfg=ffn_cfg)
|
|
|