# SudokuSolver / model / get_model.py
# Author: LTPhat — revision 1f1fc6b
import torch
from torchvision.models import resnet18, resnet101, resnet50
import torchvision
import torch.nn as nn
def get_model(model_name, pretrained=True, num_classes=10):
    """Build a torchvision ResNet adapted to single-channel (grayscale) input.

    The stem convolution ``conv1`` is replaced by a ``Conv2d`` that takes one
    input channel, and the classification head ``fc`` is replaced by a fresh
    ``nn.Linear`` producing ``num_classes`` logits.

    Args:
        model_name: One of ``"resnet18"``, ``"resnet50"`` or ``"resnet101"``.
        pretrained: Forwarded to the torchvision constructor to load ImageNet
            weights. Because ``conv1`` and ``fc`` are replaced afterwards,
            those two layers are always freshly initialised; only the rest of
            the backbone keeps the pretrained weights.
        num_classes: Number of output logits of the new head (default 10).

    Returns:
        The adapted ``torch.nn.Module``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in favour of
    # ``weights=``; kept here for compatibility with older torchvision.
    if model_name == "resnet18":
        net = torchvision.models.resnet18(pretrained=pretrained)
    elif model_name == "resnet50":
        net = torchvision.models.resnet50(pretrained=pretrained)
    elif model_name == "resnet101":
        net = torchvision.models.resnet101(pretrained=pretrained)
    else:
        # Previously fell through to ``return net`` with ``net`` unbound.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'resnet18', 'resnet50' or 'resnet101'"
        )

    # Replace the 1st layer so the network accepts grayscale (1-channel)
    # images; the other hyper-parameters are the same for all three models.
    net.conv1 = nn.Conv2d(
        1,
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # Size the new head from the backbone's own head: resnet18 feeds 512
    # features into fc while resnet50/101 feed 2048, so a hard-coded 2048
    # breaks resnet18 at forward time.
    net.fc = nn.Linear(
        in_features=net.fc.in_features, out_features=num_classes, bias=True
    )
    return net