| import random
|
| import numpy as np
|
| import torch as tr
|
|
|
# Seed Python's `random`, NumPy's global generator and PyTorch with one
# shared value so that repeated runs draw the same random numbers.
seed = 4310823
for _seed_rng in (random.seed, np.random.seed, tr.manual_seed):
    _seed_rng(seed)
del _seed_rng
|
|
|
class SimpleFeedForwardNet(tr.nn.Module):
    """Three-layer fully connected classifier.

    Layer widths default to 784 -> 512 -> 784 -> 10. A ReLU follows each
    hidden layer: without a nonlinearity, the stacked ``Linear`` layers
    would collapse into a single affine map and the hidden layers would
    add no expressive power.

    Args:
        in_features: Size of the last dimension of the input
            (784 by default -- presumably flattened 28x28 images; TODO confirm).
        hidden1: Output width of ``linear1``.
        hidden2: Output width of ``linear2``.
        num_classes: Output width of ``linear3`` (one score per class).
    """

    def __init__(self, in_features=784, hidden1=512, hidden2=784, num_classes=10):
        super().__init__()

        # Attribute names are kept as linear1/2/3 so state_dict keys stay stable.
        self.linear1 = tr.nn.Linear(in_features, hidden1)
        self.linear2 = tr.nn.Linear(hidden1, hidden2)
        self.linear3 = tr.nn.Linear(hidden2, num_classes)

        # Replace nn.Linear's default init with Xavier-uniform weights and
        # zero biases. Layers are visited in the same order as before, so
        # RNG consumption under a fixed seed is unchanged.
        for layer in (self.linear1, self.linear2, self.linear3):
            tr.nn.init.xavier_uniform_(layer.weight)
            tr.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., num_classes)``.

        Returns raw, unnormalized scores; no softmax is applied.
        """
        x = tr.relu(self.linear1(x))
        x = tr.relu(self.linear2(x))
        return self.linear3(x)
|
|
|
|
|
|
|
|
|
|
|
# Module-level training objects: the network, an SGD optimizer (momentum
# plus L2 weight decay) over all of its parameters, and a cross-entropy
# loss criterion.
model = SimpleFeedForwardNet()

optimizer = tr.optim.SGD(
    params=model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4,
)

loss_fn = tr.nn.CrossEntropyLoss()
|
|
|