yagizdevre committed on
Commit
ffefe3d
·
1 Parent(s): a78f7b3
Files changed (2) hide show
  1. __pycache__/mlp.cpython-312.pyc +0 -0
  2. mlp.py +10 -15
__pycache__/mlp.cpython-312.pyc CHANGED
Binary files a/__pycache__/mlp.cpython-312.pyc and b/__pycache__/mlp.cpython-312.pyc differ
 
mlp.py CHANGED
@@ -2,26 +2,21 @@ import torch.nn as nn
2
  from torch.nn import functional as F
3
  import torch
4
class MLP(nn.Module):
    """GELU-gated MLP block (cf. https://arxiv.org/pdf/2002.05202)."""

    def __init__(self, config, dtype=None):
        """Build the gated-MLP projections.

        Args:
            config: model config providing ``n_embd``, ``mlp_scale``,
                ``bias``, ``dropout`` and ``torch_dtype``.
            dtype: optional dtype override for the projection layers;
                defaults to ``config.torch_dtype``.
        """
        super().__init__()
        # BUG FIX: the original resolved `dtype` from the parameter /
        # config.torch_dtype but then hard-coded torch.bfloat16 in every
        # nn.Linear, silently ignoring both. Use the resolved dtype.
        dtype = dtype if dtype is not None else config.torch_dtype
        self.hidden_size = config.n_embd
        self.intermediate_size = config.n_embd * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply down_proj(gelu(gate_proj(x)) * up_proj(x)), then dropout."""
        # Cast the input ONCE to the layer dtype. The original cast x twice
        # and re-cast every projection output with .to(dtype=...), which is
        # a no-op because the layers already produce that dtype.
        x = x.to(self.gate_proj.weight.dtype)
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        up = self.up_proj(x)
        outputs = self.down_proj(gate * up)
        return self.dropout(outputs)
 
2
  from torch.nn import functional as F
3
  import torch
4
class MLP(nn.Module):
    """GELU-gated feed-forward block (cf. https://arxiv.org/pdf/2002.05202)."""

    def __init__(self, config):
        """Create the gate/up/down projections and dropout from ``config``
        (expects ``dim``, ``mlp_scale``, ``bias`` and ``dropout``)."""
        super().__init__()
        dim = config.dim
        inner = config.dim * config.mlp_scale
        self.hidden_size = dim
        self.intermediate_size = inner
        # Layers are constructed in gate -> up -> down order.
        self.gate_proj = nn.Linear(dim, inner, bias=config.bias)
        self.up_proj = nn.Linear(dim, inner, bias=config.bias)
        self.down_proj = nn.Linear(inner, dim, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return dropout(down_proj(gelu(gate_proj(x)) * up_proj(x)))."""
        activated = F.gelu(self.gate_proj(x), approximate="tanh")
        fused = activated * self.up_proj(x)
        return self.dropout(self.down_proj(fused))