| import random |
| import numpy as np |
| import torch as tr |
|
|
|
|
# One fixed seed applied, in turn, to Python's `random` module, NumPy's global
# (legacy) RNG and PyTorch's default generator -- presumably so that weight
# initialisation and any later sampling repeat exactly from run to run.
seed = 4313204
for _seeder in (random.seed, np.random.seed, tr.manual_seed):
    _seeder(seed)
del _seeder  # keep the module namespace identical to seeding call-by-call
|
|
|
|
class SimpleLinearNet(tr.nn.Module):
    """Four-layer fully connected network with ReLU between layers.

    With the default arguments the layer widths are 784 -> 512 -> 512 -> 128
    -> 10. (784 == 28 * 28 suggests flattened 28x28 images -- confirm with the
    caller.) The final layer has no activation, so ``forward`` returns raw,
    unnormalised scores.

    Layers are stored as ``linear1`` .. ``linear4`` regardless of the chosen
    widths, so ``state_dict`` keys are the same for every configuration.
    """

    def __init__(
        self,
        in_features: int = 784,
        hidden_sizes: tuple = (512, 512, 128),
        num_classes: int = 10,
    ):
        """Build the four linear layers and re-initialise their parameters.

        Args:
            in_features: size of the last dimension of the input tensor.
            hidden_sizes: widths of the three hidden layers, in order.
            num_classes: size of the last dimension of the output tensor.

        Raises:
            ValueError: if ``hidden_sizes`` does not have exactly 3 entries.
        """
        super().__init__()

        hidden_sizes = tuple(hidden_sizes)
        if len(hidden_sizes) != 3:
            raise ValueError(
                f"hidden_sizes must contain exactly 3 widths, got {len(hidden_sizes)}"
            )
        h1, h2, h3 = hidden_sizes

        # Layers are created in the same order as before, so the default
        # nn.Linear initialisation consumes the torch RNG identically.
        self.linear1 = tr.nn.Linear(in_features, h1)
        self.linear2 = tr.nn.Linear(h1, h2)
        self.linear3 = tr.nn.Linear(h2, h3)
        self.linear4 = tr.nn.Linear(h3, num_classes)

        self.init_weights()

    def init_weights(self):
        """Overwrite the default nn.Linear initialisation.

        The three hidden layers, each followed by a ReLU in ``forward``, get
        Kaiming-normal weights computed with the ReLU gain. The output layer,
        which has no activation, gets Xavier-uniform weights. All biases are
        set to zero.
        """
        for layer in [self.linear1, self.linear2, self.linear3]:
            tr.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            tr.nn.init.zeros_(layer.bias)
        tr.nn.init.xavier_uniform_(self.linear4.weight)
        tr.nn.init.zeros_(self.linear4.bias)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: tensor whose last dimension equals ``in_features``. nn.Linear
                acts on the last dimension, so leading dimensions pass through.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``num_classes``. The output is not passed through softmax or any
            other activation.
        """
        x = tr.relu(self.linear1(x))
        x = tr.relu(self.linear2(x))
        x = tr.relu(self.linear3(x))
        x = self.linear4(x)
        return x
|
|
|
|
|
|
|
|