import torch

from model import CNNModel


# NOTE(review): this file defines AlexNet (with the same imports) a second
# time further down; that later definition is the one in effect at import
# time. One of the two copies should be removed.
def AlexNet(pretrained=True, *, weights_path="alexnet_weights.pth", map_location="cpu"):
    """Build a ``CNNModel`` and optionally load saved weights into it.

    Despite the name, the architecture is whatever ``model.CNNModel``
    constructs with no arguments.

    Args:
        pretrained: When true, load a state dict from ``weights_path`` into
            the freshly constructed model before returning it.
        weights_path: Path of the saved state dict. The default,
            ``"alexnet_weights.pth"``, is a relative path and is therefore
            resolved against the current working directory, not this
            module's directory.
        map_location: Forwarded to ``torch.load`` to remap where tensors are
            materialized. Defaults to ``"cpu"``, so loading works on
            machines without the device the weights were saved from.

    Returns:
        The ``CNNModel`` instance, with weights loaded when ``pretrained``.

    Raises:
        FileNotFoundError: ``pretrained`` is true and ``weights_path`` does
            not exist.
        RuntimeError: The saved state dict's keys/shapes do not match the
            model (``load_state_dict`` runs in strict mode by default).
    """
    model = CNNModel()
    if pretrained:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code if the file is untrusted. On torch >= 1.13 consider
        # passing weights_only=True, since only a state dict is expected here.
        state_dict = torch.load(weights_path, map_location=map_location)
        model.load_state_dict(state_dict)
    return model
import torch

from model import CNNModel


# NOTE(review): this is a verbatim second definition of AlexNet (and its
# imports) in the same file; it silently replaces the earlier, identical
# definition when the module is imported. One copy should be removed.
def AlexNet(pretrained=True, *, weights_path="alexnet_weights.pth", map_location="cpu"):
    """Build a ``CNNModel`` and optionally load saved weights into it.

    Despite the name, the architecture is whatever ``model.CNNModel``
    constructs with no arguments.

    Args:
        pretrained: When true, load a state dict from ``weights_path`` into
            the freshly constructed model before returning it.
        weights_path: Path of the saved state dict. The default,
            ``"alexnet_weights.pth"``, is a relative path and is therefore
            resolved against the current working directory, not this
            module's directory.
        map_location: Forwarded to ``torch.load`` to remap where tensors are
            materialized. Defaults to ``"cpu"``, so loading works on
            machines without the device the weights were saved from.

    Returns:
        The ``CNNModel`` instance, with weights loaded when ``pretrained``.

    Raises:
        FileNotFoundError: ``pretrained`` is true and ``weights_path`` does
            not exist.
        RuntimeError: The saved state dict's keys/shapes do not match the
            model (``load_state_dict`` runs in strict mode by default).
    """
    model = CNNModel()
    if pretrained:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code if the file is untrusted. On torch >= 1.13 consider
        # passing weights_only=True, since only a state dict is expected here.
        state_dict = torch.load(weights_path, map_location=map_location)
        model.load_state_dict(state_dict)
    return model