File size: 932 Bytes
14f592d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17c0b30
14f592d
 
17c0b30
 
 
 
14f592d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
import torch.nn.functional as F


class GELU(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self,x):
        return 0.5 *  x * (
            1+torch.tanh(
                torch.sqrt(torch.tensor(2/torch.pi))*(x+0.044715*torch.pow(x,3))
                )
        )


class UstaMLP(nn.Module):
    def __init__(self,embedding_dim,hidden_dim,device="cpu"):
        super().__init__()

        self.gate_proj = nn.Linear(embedding_dim,hidden_dim,device=device)
        self.up_proj = nn.Linear(embedding_dim,hidden_dim,device=device)
        self.down_proj = nn.Linear(hidden_dim,embedding_dim,device=device)
        self.gelu = GELU().to(device)
    
    def forward(self,x):
        gate = self.gate_proj(x)
        gate = self.gelu(gate)
        up = self.up_proj(x)
        fuse = gate*up
        outputs = self.down_proj(fuse)
        return outputs