File size: 248 Bytes
2dfdcd4
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch
from model import CNNModel

def AlexNet(pretrained=True, *, weights_path="alexnet_weights.pth",
            map_location="cpu", strict=True):
    """Build a ``CNNModel`` and optionally load saved weights into it.

    Args:
        pretrained: When true, load a state dict from ``weights_path`` into
            the freshly constructed model before returning it.
        weights_path: Path of the saved state dict. The default is resolved
            relative to the current working directory, not this module.
        map_location: Forwarded to ``torch.load``; ``"cpu"`` lets weights
            saved on a GPU be loaded on a CPU-only machine.
        strict: Forwarded to ``load_state_dict``; when true, the checkpoint
            keys must match the model's parameters exactly.

    Returns:
        The ``CNNModel`` instance, with weights loaded if ``pretrained``.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        whatever ``load_state_dict`` raises for a mismatched checkpoint.
    """
    model = CNNModel()

    if pretrained:
        # NOTE(review): torch.load unpickles the file, so only point
        # weights_path at trusted checkpoints. Consider weights_only=True
        # if the installed torch version supports it.
        state_dict = torch.load(weights_path, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)

    return model