# AlexNet / hubconf.py
# Last commit: 2dfdcd4 by DrujZ-cmd — "AI417 A5 AlexNet model"
from pathlib import Path

import torch

from model import CNNModel
def AlexNet(pretrained: bool = True, *, weights_path=None):
    """Build a ``CNNModel``, optionally initialised from saved weights.

    This is a ``torch.hub`` entrypoint, so this docstring is what
    ``torch.hub.help(...)`` displays for it.

    Args:
        pretrained: When true (the default), load a state dict into the model
            before returning it. When false, the model is returned with
            whatever initialisation ``CNNModel()`` itself performs.
        weights_path: Optional path to the state-dict file. When omitted,
            ``alexnet_weights.pth`` located next to this file is used if it
            exists; otherwise the bare relative name ``alexnet_weights.pth``
            is tried (resolved against the current working directory, which
            was the previous behaviour).

    Returns:
        The constructed ``CNNModel`` instance.

    Raises:
        FileNotFoundError: If ``pretrained`` is true and the weights file
            cannot be found.
        RuntimeError: If the stored state dict does not match the model's
            parameters (``load_state_dict`` runs in its default strict mode).
    """
    model = CNNModel()
    if not pretrained:
        return model

    if weights_path is None:
        # torch.hub imports hubconf.py from the downloaded repo but does not
        # chdir into it, so a bare relative path would be looked up in the
        # caller's working directory. Prefer the file shipped beside this
        # module; fall back to the old CWD-relative name for compatibility.
        bundled = Path(__file__).resolve().parent / "alexnet_weights.pth"
        weights_path = bundled if bundled.is_file() else "alexnet_weights.pth"

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # weights. For a plain state dict, passing weights_only=True
    # (torch >= 1.13) would be safer; confirm the supported torch version.
    # map_location="cpu" lets GPU-saved weights load on CPU-only machines.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model