File size: 473 Bytes
6733993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch.nn as nn

class MultiplicationNet(nn.Module):
    """Fully connected feed-forward network with ReLU hidden activations.

    With the default configuration the network takes inputs whose last
    dimension is 2 and produces outputs whose last dimension is 1, passing
    through hidden layers of width 64, 128 and 256. Judging by the class
    name it is presumably meant to approximate the product of two numbers;
    nothing in the architecture itself is specific to that task.

    Args:
        layer_sizes: Optional sequence of layer widths, from input width to
            output width. Each consecutive pair becomes one ``nn.Linear``.
            Defaults to ``(2, 64, 128, 256, 1)``.

    Raises:
        ValueError: If ``layer_sizes`` contains fewer than two entries, since
            at least an input and an output width are needed.
    """

    def __init__(self, layer_sizes=None):
        super().__init__()
        if layer_sizes is None:
            sizes = (2, 64, 128, 256, 1)
        else:
            sizes = tuple(layer_sizes)
        if len(sizes) < 2:
            raise ValueError(
                "layer_sizes must contain at least an input and an output size"
            )

        layers = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            # Put the activation *between* linear layers only, so the final
            # layer stays linear and the output is not clamped to be >= 0.
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(in_features, out_features))

        # Attribute name and module order are kept so state_dict keys
        # ("model.0.weight", "model.2.weight", ...) remain unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result.

        ``x`` must have a last dimension equal to the first configured layer
        size (2 by default); the result has a last dimension equal to the
        final layer size (1 by default).
        """
        return self.model(x)