# A4_4310823_Sema / dummy.py.py
# SemaAli99's picture
# Upload 2 files
# 9df2737 verified
import random
import numpy as np
import torch as tr
seed = 4310823  # <<<<<<<<<<<<<<<< Your UPM ID Goes Here

# Seed Python's, NumPy's and PyTorch's global RNGs with the same value so
# that weight initialisation (and anything else drawn from them) is
# reproducible across runs.
for _seed_rng in (random.seed, np.random.seed, tr.manual_seed):
    _seed_rng(seed)
del _seed_rng
class SimpleFeedForwardNet(tr.nn.Module):
    """Fully connected classifier with layer sizes 784 -> 512 -> 784 -> 10.

    Each hidden layer is followed by a ReLU. Without a nonlinearity between
    them, the three ``Linear`` layers would compose into one affine map
    (784 -> 10), and the hidden layers would add no expressive power. The
    last layer is left linear, so the output is raw (unnormalised) logits.

    Weights are Xavier-uniform initialised; biases start at zero.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = tr.nn.Linear(784, 512)
        self.linear2 = tr.nn.Linear(512, 784)
        self.linear3 = tr.nn.Linear(784, 10)
        # Re-initialise every layer. Only xavier_uniform_ draws random
        # numbers, and the weights are visited in the same order as before,
        # so seeded runs produce the same initial weights.
        for layer in (self.linear1, self.linear2, self.linear3):
            tr.nn.init.xavier_uniform_(layer.weight)
            tr.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map input features to class logits.

        Args:
            x: float tensor whose last dimension has size 784 (assumed to
               be an already-flattened batch such as ``(batch, 784)`` --
               TODO confirm against the data loader).

        Returns:
            Tensor whose last dimension has size 10, holding raw logits
            (no softmax applied).
        """
        # Functional ReLU keeps the module's parameters / state_dict keys
        # identical to the purely linear version.
        x = tr.relu(self.linear1(x))
        x = tr.relu(self.linear2(x))
        return self.linear3(x)
model = SimpleFeedForwardNet()

# SGD with momentum 0.9 over all of the model's parameters, using
# weight_decay=5e-4 as L2 regularisation.
optimizer = tr.optim.SGD(
    params=model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
)

# Objective Function [DO NOT CHANGE !]
loss_fn = tr.nn.CrossEntropyLoss()