| import torch.nn as nn | |
class MultiplicationNet(nn.Module):
    """Fully connected ReLU network (MLP) built from a sequence of layer widths.

    With the default widths the network maps 2 input features to 1 output
    through hidden layers of 512, 1024 and 2048 units. Going by the class
    name it is presumably trained to approximate the product of its two
    inputs -- NOTE(review): confirm against the training code.

    Args:
        layer_sizes: Widths of consecutive layers, input first and output
            last. Each adjacent pair ``(a, b)`` becomes ``nn.Linear(a, b)``;
            a ``ReLU`` follows every linear layer except the last one, so the
            final output is not clamped to be non-negative. Defaults to
            ``DEFAULT_LAYER_SIZES``.

    Raises:
        ValueError: If fewer than two layer sizes are given, since no linear
            layer could be built.
    """

    # Same architecture as the original hard-coded list, so default-constructed
    # instances keep identical parameter shapes and state_dict keys.
    DEFAULT_LAYER_SIZES = (2, 512, 1024, 2048, 1)

    def __init__(self, layer_sizes=DEFAULT_LAYER_SIZES):
        super().__init__()
        sizes = tuple(layer_sizes)
        if len(sizes) < 2:
            raise ValueError(
                "layer_sizes needs at least 2 entries (input and output), "
                f"got {sizes!r}"
            )
        layers = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # Drop the trailing ReLU so the last Linear's output is left unclamped.
        # Building Linear/ReLU interleaved and then popping keeps the module
        # indices (model.0, model.2, ...) that saved state_dicts rely on.
        layers.pop()
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network.

        Args:
            x: Input tensor whose last dimension equals ``layer_sizes[0]``
                (2 by default), e.g. shape ``(batch, 2)``.

        Returns:
            Tensor whose last dimension is ``layer_sizes[-1]`` (1 by default).
        """
        return self.model(x)