File size: 893 Bytes
cbda9b7
 
a78f7b3
cbda9b7
ffefe3d
cbda9b7
 
ffefe3d
 
 
 
 
cbda9b7
ffefe3d
cbda9b7
 
ffefe3d
 
cbda9b7
ffefe3d
cbda9b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.nn as nn
from torch.nn import functional as F
import torch
class MLP(nn.Module):
    """Gated feed-forward block (GLU variant with a tanh-approximated GELU gate).

    Computes ``dropout(down_proj(gelu_tanh(gate_proj(x)) * up_proj(x)))``,
    following the gated-linear-unit variants described in the linked paper.

    Attributes read from ``config``:
        dim: input/output feature width.
        mlp_scale: multiplier giving the hidden width (``dim * mlp_scale``);
            may be an int or a float.
        bias: whether the three linear layers carry bias terms.
        dropout: dropout probability applied to the block's output.
    """

    def __init__(self, config):
        # https://arxiv.org/pdf/2002.05202
        super().__init__()
        self.hidden_size = config.dim
        # nn.Linear requires integer feature counts. A fractional mlp_scale
        # would otherwise yield a float here and crash layer construction;
        # int() truncates toward zero and is a no-op for integer products.
        self.intermediate_size = int(config.dim * config.mlp_scale)
        # Two parallel up-projections: one is gated by GELU, the other is the
        # linear "value" path it multiplies.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the gated MLP to ``x``.

        Args:
            x: tensor whose last dimension is ``config.dim``; the linear layers
                act on that last dimension and leading dimensions pass through.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``config.dim``.
        """
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        # Element-wise gating of the value path by the activated gate path.
        fuse = gate * up
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs