import torch
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    def __init__(self, config):
        # Gated MLP (GeGLU variant): https://arxiv.org/pdf/2002.05202
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        # Two parallel up-projections: one is passed through a GELU and acts
        # as a gate on the other.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        # Projection from the intermediate width back to the model dimension.
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # GeGLU: GELU(x @ W_gate) * (x @ W_up), then project back down.
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
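
# A minimal usage sketch, assuming a simple config object exposing the fields
# the module reads (dim, mlp_scale, bias, dropout). _Config below is a
# hypothetical stand-in; the real project config may look different.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _Config:
        dim: int = 64
        mlp_scale: int = 4
        bias: bool = False
        dropout: float = 0.0

    mlp = MLP(_Config())
    x = torch.randn(2, 8, 64)  # (batch, seq_len, dim)
    print(mlp(x).shape)  # the block preserves the input shape: (2, 8, 64)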