import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import resnet18


class ResNet(nn.Module):
    """ResNet-18 classifier with a configurable number of input channels.

    Wraps torchvision's ``resnet18``. No ``weights`` argument is passed, so the
    network is randomly initialised. Its first convolution is replaced so that
    it accepts ``in_channels`` input planes instead of torchvision's default of 3.

    Args:
        in_channels: Number of channels in the input images. Defaults to 1,
            matching the original hard-coded single-channel stem.
        num_classes: Number of output logits produced by the final linear
            layer. Defaults to 10.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        self.model = resnet18(num_classes=num_classes)
        # Swap the stem conv. Kernel size, stride, padding and bias match
        # torchvision's own ResNet conv1; only the input-channel count changes.
        # NOTE(review): together with the stem max-pool this downsamples 4x
        # up front, which is aggressive for very small images (e.g. 28x28) --
        # confirm this is intended for the target dataset.
        self.model.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ResNet-18 and return its class logits.

        Args:
            x: Input batch. Presumably shaped ``(batch, in_channels, H, W)``,
                since it is fed straight into a 2-D convolution; confirm
                against callers.

        Returns:
            Logits from the wrapped model's final linear layer, with
            ``num_classes`` values per sample.
        """
        return self.model(x)