draftnewapp / models /convmodel.py
binaychandra's picture
restructured the code
285d843
raw
history blame contribute delete
890 Bytes
import torch
from torch import nn
class MNISTnet(nn.Module):
    """Small two-convolution image classifier (sized for MNIST by default).

    Architecture::

        Conv3x3(input_channels -> hidden_layers, 'same') -> ReLU
        Conv3x3(hidden_layers -> num_labels, 'same') -> ReLU -> MaxPool2x2
        Flatten -> Linear(num_labels * (input_size // 2) ** 2 -> num_labels)

    Args:
        input_channels: Number of channels in the input images.
        num_labels: Number of output classes. It is also used as the channel
            count of the second convolution, as in the original design; this
            keeps parameter names and shapes unchanged for existing checkpoints.
        hidden_layers: Number of channels produced by the first convolution.
        input_size: Side length of the square input images. Defaults to 28
            (the value the original hard-coded 14*14 flatten size assumed).
    """

    def __init__(self, input_channels, num_labels, hidden_layers, input_size=28):
        super().__init__()
        self.block_one = nn.Sequential(
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=hidden_layers,
                kernel_size=3,
                stride=1,
                padding='same',
            ),
            nn.ReLU(),
        )
        self.block_two = nn.Sequential(
            nn.Conv2d(
                in_channels=hidden_layers,
                out_channels=num_labels,
                kernel_size=3,
                stride=1,
                padding='same',
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 'same' padding keeps the spatial size; the 2x2 / stride-2 pool then
        # floors it to input_size // 2 (14 for 28x28 inputs).
        pooled_size = input_size // 2
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Output width follows num_labels; it was previously hard-coded to 10.
            nn.Linear(
                in_features=num_labels * pooled_size * pooled_size,
                out_features=num_labels,
                bias=True,
            ),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_labels).

        ``x`` is expected to be a batched tensor of shape
        (batch, input_channels, input_size, input_size). Other spatial sizes
        will not match the classifier's ``in_features``.
        """
        x = self.block_one(x)
        x = self.block_two(x)
        return self.classifier(x)