File size: 622 Bytes
cde085e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfefcfd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn

class Model(nn.Module):
    """Two-stage ELU MLP with per-channel batch norm, summed over dim 1.

    Applies ``Linear -> BatchNorm1d -> ELU`` twice along the last dimension,
    sums the result over dimension 1, and projects with a final ``Linear``.

    ``nn.BatchNorm1d(num_channels)`` normalizes over dimension 1, so the input
    must be 3-D with shape ``(batch, num_channels, in_features)``. The output
    has shape ``(batch, out_features)``.

    Every size argument defaults to the originally hard-coded value
    (768 -> 1024 -> 512 -> 1, with 2 channels). ``Model()`` therefore builds
    the same architecture with the same submodule names (``ll1``, ``bn1``,
    ``elu1``, ``ll2``, ``bn2``, ``elu2``, ``llf``) as before, so existing
    checkpoints remain compatible.

    Args:
        in_features: Size of the last input dimension.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_channels: Size of dimension 1, normalized by both BatchNorm layers.
        out_features: Size of the final output.
    """

    def __init__(
        self,
        in_features: int = 768,
        hidden1: int = 1024,
        hidden2: int = 512,
        num_channels: int = 2,
        out_features: int = 1,
    ):
        super().__init__()
        # Attribute names are kept unchanged so state_dict keys stay stable.
        self.ll1 = nn.Linear(in_features, hidden1)
        self.bn1 = nn.BatchNorm1d(num_channels)
        self.elu1 = nn.ELU()
        self.ll2 = nn.Linear(hidden1, hidden2)
        self.bn2 = nn.BatchNorm1d(num_channels)
        self.elu2 = nn.ELU()
        self.llf = nn.Linear(hidden2, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, num_channels, in_features)`` to ``(batch, out_features)``."""
        x = self.elu1(self.bn1(self.ll1(x)))
        x = self.elu2(self.bn2(self.ll2(x)))
        # Collapse the channel dimension (the one BatchNorm normalized over).
        x = torch.sum(x, dim=1)
        x = self.llf(x)
        return x


if __name__ == '__main__':
    # NOTE(review): the name `model` suggests model.pth holds a pickled full
    # nn.Module (not a state_dict) -- confirm. Since torch 2.6, torch.load
    # defaults to weights_only=True, which rejects such pickles, so the old
    # behavior is requested explicitly here (the keyword needs torch >= 1.13).
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code -- only load trusted checkpoint files this way.
    model = torch.load('model.pth', weights_only=False)