import torch
import torchvision
from torch import nn


# Function for ResNet18.
def create_resnet(out_features: int = 1) -> nn.Module:
    """Build a pretrained ResNet18 with a frozen backbone and a new linear head.

    The model is loaded with ``ResNet18_Weights.DEFAULT``. Every parameter of
    the loaded network has ``requires_grad`` set to ``False``, and then the
    final ``fc`` layer is replaced by a freshly initialised ``nn.Linear``.
    Because the replacement happens after the freeze loop, the new head is
    the only part of the returned model left trainable.

    Args:
        out_features: Width of the new output layer. Defaults to ``1``,
            which matches the original hard-coded head.

    Returns:
        The modified ResNet18 model.

    Note:
        The preprocessing that matches these weights is available from
        ``torchvision.models.ResNet18_Weights.DEFAULT.transforms()``; it is
        not returned by this function.
    """
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    model = torchvision.models.resnet18(weights=weights)

    # Freeze the pretrained backbone before the head is swapped, so the
    # replacement layer below keeps its default requires_grad=True.
    for param in model.parameters():
        param.requires_grad = False

    # Read the input width from the existing head rather than hard-coding
    # 512; for ResNet18 this is the same value.
    model.fc = nn.Linear(
        in_features=model.fc.in_features,
        out_features=out_features,
        bias=True,
    )
    return model