2dfdcd4
1
2
3
4
5
6
7
8
9
10
11
import torch

from model import CNNModel


def AlexNet(pretrained=True, *, weights_path="alexnet_weights.pth"):
    """Build a ``CNNModel`` and optionally load saved weights into it.

    Args:
        pretrained: When truthy, a state dict is read from ``weights_path``
            and loaded into the model. When falsy, the freshly constructed
            model is returned without touching the filesystem.
        weights_path: Checkpoint location handed to ``torch.load``. The
            default, ``"alexnet_weights.pth"``, is a relative path and is
            therefore resolved against the current working directory.

    Returns:
        The ``CNNModel`` instance (with the loaded weights when
        ``pretrained`` is truthy).

    Raises:
        Any exception raised by ``torch.load`` (for example when the file
        is missing) or by ``model.load_state_dict`` is propagated unchanged.
    """
    model = CNNModel()
    if not pretrained:
        return model

    # NOTE(security): torch.load unpickles the checkpoint, so only trusted
    # files should be loaded here. If the installed torch version supports
    # it, consider passing weights_only=True -- TODO confirm the minimum
    # torch version this project targets before enabling it.
    # map_location="cpu" keeps the tensors on the CPU regardless of the
    # device they were saved from.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model