Commit
·
ffefe3d
1
Parent(s):
a78f7b3
fix
Browse files- __pycache__/mlp.cpython-312.pyc +0 -0
- mlp.py +10 -15
__pycache__/mlp.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/mlp.cpython-312.pyc and b/__pycache__/mlp.cpython-312.pyc differ
|
|
|
mlp.py
CHANGED
|
@@ -2,26 +2,21 @@ import torch.nn as nn
|
|
| 2 |
from torch.nn import functional as F
|
| 3 |
import torch
|
| 4 |
class MLP(nn.Module):
|
| 5 |
-
def __init__(self, config
|
| 6 |
# https://arxiv.org/pdf/2002.05202
|
| 7 |
super().__init__()
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
self.hidden_size =
|
| 11 |
-
self.
|
| 12 |
-
self.
|
| 13 |
-
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=torch.bfloat16)
|
| 14 |
-
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=torch.bfloat16)
|
| 15 |
self.dropout = nn.Dropout(config.dropout)
|
| 16 |
-
|
| 17 |
def forward(self, x):
|
| 18 |
-
dtype = self.gate_proj.weight.dtype # Match the dtype of projection layers
|
| 19 |
-
x = x.to(dtype=dtype) # Convert input to the same dtype
|
| 20 |
-
x = x.to(self.gate_proj.weight.dtype)
|
| 21 |
gate = self.gate_proj(x)
|
| 22 |
-
gate = F.gelu(gate, approximate="tanh")
|
| 23 |
-
up = self.up_proj(x)
|
| 24 |
fuse = gate * up
|
| 25 |
-
outputs = self.down_proj(fuse)
|
| 26 |
outputs = self.dropout(outputs)
|
| 27 |
return outputs
|
|
|
|
| 2 |
from torch.nn import functional as F
|
| 3 |
import torch
|
| 4 |
class MLP(nn.Module):
    """Gated feed-forward block with a tanh-approximated GELU gate.

    The input is projected twice to ``intermediate_size`` (a "gate" branch
    and an "up" branch); the GELU-activated gate multiplies the up branch
    element-wise, and the product is projected back to ``hidden_size`` and
    passed through dropout. See https://arxiv.org/pdf/2002.05202.

    Args:
        config: Object exposing ``dim`` (model width), ``mlp_scale``
            (multiplier for the inner width), ``bias`` (whether the linear
            layers carry a bias term) and ``dropout`` (dropout probability).
    """

    def __init__(self, config):
        # https://arxiv.org/pdf/2002.05202
        super().__init__()
        self.hidden_size = config.dim
        # nn.Linear only accepts integer feature counts, so a fractional
        # mlp_scale (e.g. 2.5) must be truncated to an int here; integer
        # scales produce exactly the same value as before.
        self.intermediate_size = int(config.dim * config.mlp_scale)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``.

        Args:
            x: Tensor whose last dimension equals ``hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Activated gate branch.
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        # Linear "up" branch, modulated element-wise by the gate.
        up = self.up_proj(x)
        fuse = gate * up
        # Project back to the model width, then regularise.
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
|