tchauffi's picture
Add MNIST model code and dependencies
d128a86
from torch import nn
import torch
import torch.nn.functional as F
from torchvision.models import resnet18
class ResNet(nn.Module):
    """ResNet-18 classifier whose stem convolution accepts a custom channel count.

    Wraps ``torchvision.models.resnet18`` and replaces its first convolution
    so the network can consume inputs that do not have three channels. With
    the default arguments the network has a 1-channel stem and a 10-way
    output (presumably for MNIST, going by the file's commit message).

    Args:
        num_classes: Number of output classes, forwarded to ``resnet18``.
        in_channels: Number of channels expected by the first convolution.
    """

    def __init__(self, *, num_classes: int = 10, in_channels: int = 1) -> None:
        super().__init__()
        self.model = resnet18(num_classes=num_classes)
        # Swap the stem convolution for one with the requested number of
        # input channels; the remaining hyper-parameters (64 filters, 7x7
        # kernel, stride 2, padding 3, no bias) are kept as before.
        self.model.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped network and return its output.

        NOTE(review): ``x`` is presumably shaped
        ``(batch, in_channels, height, width)`` -- confirm against the caller.
        """
        return self.model(x)