# File size: 622 bytes
# Revisions: cde085e, dfefcfd
import torch
import torch.nn as nn
class Model(nn.Module):
    """Small MLP that pools a fixed number of input vectors into one score.

    Architecture: two rounds of Linear -> BatchNorm1d -> ELU, then a sum over
    dim 1, then a final Linear down to a single output.

    ``BatchNorm1d(num_inputs)`` is applied to the Linear outputs. It normalizes
    over dim 1, so ``forward`` needs a 3-D input of shape
    ``(batch, num_inputs, in_features)``. A 2-D input fails unless the hidden
    widths happen to equal ``num_inputs``. The sum over dim 1 then collapses
    the ``num_inputs`` axis.

    NOTE(review): what the ``num_inputs`` slots represent (e.g. a pair of
    embeddings) is not visible here. Confirm against the data pipeline.

    Args:
        in_features: Size of the last input dimension. Defaults to 768.
        hidden1: Width of the first hidden layer. Defaults to 1024.
        hidden2: Width of the second hidden layer. Defaults to 512.
        num_inputs: Size of input dim 1; also the BatchNorm channel count.
            Defaults to 2.
    """

    def __init__(
        self,
        *,
        in_features: int = 768,
        hidden1: int = 1024,
        hidden2: int = 512,
        num_inputs: int = 2,
    ) -> None:
        super().__init__()
        # Attribute names form the state_dict keys (e.g. "ll1.weight"). They
        # are kept unchanged so existing checkpoints and pickles still load.
        # Construction order also matches the original, so parameter
        # initialization consumes the RNG identically.
        self.ll1 = nn.Linear(in_features, hidden1)
        self.bn1 = nn.BatchNorm1d(num_inputs)
        self.elu1 = nn.ELU()
        self.ll2 = nn.Linear(hidden1, hidden2)
        self.bn2 = nn.BatchNorm1d(num_inputs)
        self.elu2 = nn.ELU()
        self.llf = nn.Linear(hidden2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of stacked inputs to one score per sample.

        Args:
            x: Tensor of shape ``(batch, num_inputs, in_features)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        x = self.elu1(self.bn1(self.ll1(x)))  # (batch, num_inputs, hidden1)
        x = self.elu2(self.bn2(self.ll2(x)))  # (batch, num_inputs, hidden2)
        # Pool over the num_inputs axis by summation.
        x = torch.sum(x, dim=1)  # (batch, hidden2)
        return self.llf(x)  # (batch, 1)
if __name__ == '__main__':
    # 'model.pth' is assigned to `model`, which suggests it holds a whole
    # pickled nn.Module rather than a bare state_dict. NOTE(review): confirm.
    # Since torch 2.6, torch.load defaults to weights_only=True, which refuses
    # to unpickle arbitrary classes such as Model. The flag is therefore set
    # explicitly.
    # SECURITY: weights_only=False runs full pickle deserialization, which can
    # execute arbitrary code. Only load checkpoint files from trusted sources.
    model = torch.load('model.pth', weights_only=False)