Spaces:
Sleeping
Sleeping
File size: 890 Bytes
e2f7ccb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
from torch import nn
class MNISTnet(nn.Module):
    """Small two-block CNN image classifier (MNIST-sized inputs by default).

    Architecture::

        Conv3x3(input_channels -> hidden_layers, padding='same') -> ReLU
        Conv3x3(hidden_layers -> num_labels, padding='same') -> ReLU -> MaxPool2x2
        Flatten -> Linear(num_labels * (H // 2) * (W // 2) -> num_classes)

    Submodule names and parameter shapes are unchanged from the original
    definition when the defaults are used, so existing ``state_dict``s load.

    Args:
        input_channels: Number of channels in the input images.
        num_labels: Number of feature channels produced by the second
            convolution. Despite the name, this is a channel count and not
            the number of output classes (see ``num_classes``).
        hidden_layers: Number of feature channels produced by the first
            convolution.
        num_classes: Number of output logits. Defaults to 10.
        input_size: Spatial size of the input images, either an ``int``
            (square images) or a ``(height, width)`` pair. Defaults to 28,
            which gives the original ``14 * 14`` flattened feature map.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least one
            pixel per spatial dimension after the 2x2 max-pool.
    """

    def __init__(self, input_channels, num_labels, hidden_layers, *,
                 num_classes=10, input_size=28):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # 'same' padding keeps both conv outputs at (height, width);
        # MaxPool2d(kernel_size=2, stride=2) then floors each dimension by half.
        pooled_h, pooled_w = height // 2, width // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for the 2x2 max-pool"
            )

        self.block_one = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=hidden_layers,
                      kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
        )
        self.block_two = nn.Sequential(
            nn.Conv2d(in_channels=hidden_layers, out_channels=num_labels,
                      kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=num_labels * pooled_h * pooled_w,
                      out_features=num_classes, bias=True),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Batched image tensor of shape
                ``(N, input_channels, height, width)`` matching ``input_size``.
                ``nn.Flatten()`` keeps dim 0 as the batch dimension, so
                unbatched input is not supported.

        Returns:
            Tensor of shape ``(N, num_classes)`` with unnormalized logits.
        """
        x = self.block_one(x)
        x = self.block_two(x)
        x = self.classifier(x)
        return x