# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Optional

import torch.nn.functional as F
from torch import nn

class FeedForward(nn.Module):
    """Llama3 SwiGLU FeedForward layer.

    Reference:
    https://github.com/meta-llama/llama3/blob/a0940f9cf7065d45bb6675660f80d305c041a754/llama/model.py#L193
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        output_dim: Optional[int] = None,
        skip_w2: bool = False,
    ):
        super().__init__()
        # SwiGLU uses two input projections (w1 and w3), so the hidden size
        # is scaled by 2/3 to keep the parameter count comparable to a
        # standard two-matrix MLP.
        hidden_dim = int(2 * hidden_dim / 3)
        # Custom dim factor multiplier.
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the nearest multiple of `multiple_of` for
        # hardware-friendly matrix shapes.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        if output_dim is None:
            output_dim = dim
        self.skip_w2 = skip_w2
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        if not self.skip_w2:
            self.w2 = nn.Linear(hidden_dim, output_dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

    def forward(self, x):
        # SwiGLU: silu(w1(x)) gates w3(x) elementwise.
        x = F.silu(self.w1(x)) * self.w3(x)
        if self.skip_w2:
            # Return the gated hidden activation without the output projection.
            return x
        return self.w2(x)
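

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the dim, hidden_dim, batch,
    # and sequence sizes below are arbitrary assumptions, not values taken
    # from the Llama3 reference).
    import torch

    ffn = FeedForward(dim=512, hidden_dim=4 * 512)
    x = torch.randn(2, 16, 512)  # (batch, seq_len, dim)
    print(ffn(x).shape)  # torch.Size([2, 16, 512]); output_dim defaults to dim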