# A4_4313204_Fardows / vanilla.py
# Fardows11's picture
# Update vanilla.py
# a143d73 verified
import random
import numpy as np
import torch as tr
# Single fixed seed shared by every random number source this module touches,
# so repeated runs start from the same random state.
seed = 4313204

# Apply the seed to Python's `random`, NumPy's global RNG and torch, in that order.
for _apply_seed in (random.seed, np.random.seed, tr.manual_seed):
    _apply_seed(seed)
del _apply_seed  # do not leak the loop variable into the module namespace
class SimpleLinearNet(tr.nn.Module):
    """Fully connected network: 784 -> 512 -> 512 -> 128 -> 10.

    The first three linear layers are each followed by a ReLU; the last
    layer's output is returned as-is (no softmax or other activation).
    The 784-wide input is presumably a flattened 28x28 image -- confirm
    against the caller.
    """

    def __init__(self) -> None:
        """Create the four linear layers and apply the custom initialisation."""
        super().__init__()
        self.linear1 = tr.nn.Linear(784, 512)
        self.linear2 = tr.nn.Linear(512, 512)
        self.linear3 = tr.nn.Linear(512, 128)
        self.linear4 = tr.nn.Linear(128, 10)
        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise every layer's weights in place and zero all biases.

        Layers 1-3 feed a ReLU in ``forward`` and get Kaiming-normal weights
        with ``nonlinearity='relu'``; the final layer has no activation after
        it and gets Xavier-uniform weights.
        """
        for layer in (self.linear1, self.linear2, self.linear3):
            tr.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            tr.nn.init.zeros_(layer.bias)
        tr.nn.init.xavier_uniform_(self.linear4.weight)
        tr.nn.init.zeros_(self.linear4.bias)

    def forward(self, x: tr.Tensor) -> tr.Tensor:
        """Run ``x`` through the network.

        Args:
            x: Tensor whose last dimension has size 784 (the in-features of
                ``linear1``). No reshaping is done here.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 10, taken directly from ``linear4``.
        """
        x = tr.relu(self.linear1(x))
        x = tr.relu(self.linear2(x))
        x = tr.relu(self.linear3(x))
        return self.linear4(x)