import torch

from model import CNNModel


def AlexNet(pretrained=True, *, weights_path="alexnet_weights.pth"):
    """Build a ``CNNModel`` and optionally load saved weights into it.

    Args:
        pretrained: When true, read a state dict from ``weights_path`` and
            load it into the freshly constructed model. When false, the
            model is returned exactly as ``CNNModel()`` built it.
        weights_path: Location of the checkpoint file passed to
            ``torch.load``. Defaults to ``"alexnet_weights.pth"``, which is
            resolved relative to the current working directory.

    Returns:
        The ``CNNModel`` instance.

    Raises:
        Whatever ``torch.load`` or ``model.load_state_dict`` raise, e.g. when
        the file is missing or its contents do not match the model
        (presumably ``CNNModel`` is a ``torch.nn.Module`` -- confirm in
        ``model.py``).
    """
    model = CNNModel()
    if not pretrained:
        return model

    # NOTE(security): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source. On PyTorch versions that support it,
    # consider passing weights_only=True.
    # map_location="cpu" remaps stored tensors onto the CPU so the checkpoint
    # can be read on machines without the device it was saved from.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model