| import os |
|
|
| from torch.utils import model_zoo |
|
|
| from classes.fc4.squeezenet.SqueezeNet import SqueezeNet |
|
|
# Download locations of the torchvision SqueezeNet checkpoints, keyed by the
# architecture version number (the same value passed to SqueezeNetLoader).
model_urls = {
    1.0: "https://download.pytorch.org/models/squeezenet1_0-a815701f.pth",
    1.1: "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
}
|
|
|
|
class SqueezeNetLoader:
    """
    Builds a SqueezeNet of a given version and optionally loads pre-trained
    ImageNet weights into it.
    """

    def __init__(self, version: float = 1.1):
        """
        @param version: SqueezeNet version; passed to the SqueezeNet constructor
                        and later used as the key into model_urls
        """
        self.__version = version
        # The network is built once here; every load() call returns this same instance.
        self.__model = SqueezeNet(self.__version)

    def load(self, pretrained: bool = False) -> SqueezeNet:
        """
        Returns the specified version of SqueezeNet
        @param pretrained: if True, downloads (or reuses a cached copy of) the
                           checkpoint listed in model_urls for this version and
                           loads it into the model before returning it
        @return: the SqueezeNet instance created in __init__
        """
        if not pretrained:
            return self.__model

        # Redirect torch's download cache to the relative directory
        # assets/pretrained, resolved against the current working directory.
        # NOTE(review): this mutates TORCH_HOME for the whole process and is
        # never restored; the exact sub-folder torch uses under TORCH_HOME
        # depends on the installed torch version — confirm before relying on it.
        os.environ['TORCH_HOME'] = os.path.join("assets", "pretrained")

        # An unsupported version raises KeyError here.
        checkpoint_url = model_urls[self.__version]
        state_dict = model_zoo.load_url(checkpoint_url)
        self.__model.load_state_dict(state_dict)
        return self.__model
|
|