# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Optional

import torch.nn.functional as F
from torch import nn


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        output_dim: Optional[int] = None,
        skip_w2: bool = False,
    ):
        """
        Llama3 FeedForward layer
            https://github.com/meta-llama/llama3/blob/a0940f9cf7065d45bb6675660f80d305c041a754/llama/model.py#L193
        """
        super().__init__()
        # Llama sizing convention: scale the hidden dim by 2/3 so the
        # three-matrix SwiGLU FFN has roughly the parameter count of a
        # two-matrix FFN with the requested hidden_dim.
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round hidden_dim up to the nearest multiple of `multiple_of`
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        if output_dim is None:
            output_dim = dim

        self.skip_w2 = skip_w2
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        if not self.skip_w2:
            self.w2 = nn.Linear(hidden_dim, output_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: silu(w1(x)) gates w3(x); w2 then projects back to output_dim
        x = F.silu(self.w1(x)) * self.w3(x)
        if self.skip_w2:
            return x
        return self.w2(x)
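

# A minimal usage sketch, not part of the original file; the dimensions below
# are illustrative placeholders. With the defaults, hidden_dim=4096 becomes
# int(2 * 4096 / 3) = 2730, rounded up to the next multiple of 256, i.e. 2816.
if __name__ == "__main__":
    import torch

    ffn = FeedForward(dim=1024, hidden_dim=4096)
    x = torch.randn(2, 8, 1024)
    out = ffn(x)
    assert out.shape == (2, 8, 1024)  # output_dim defaults to dim
    assert ffn.w1.out_features == 2816  # effective hidden size after rounding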