import os

import torch

# Packages that must be installed for the entry points below to work; this
# module-level name follows the torch.hub hubconf.py convention.
dependencies = ['torch']
def DummyNet(pretrained=True, **kwargs):
    """Entry point for the DummyNet model.

    Args:
        pretrained: If True, load the state dict stored in
            ``dummy-weights.bin`` (in the same directory as this file)
            into the freshly constructed model.
        **kwargs: Forwarded unchanged to ``dummy.SimpleFeedForwardNet``.

    Returns:
        The constructed ``dummy.SimpleFeedForwardNet`` instance.
    """
    # Imported inside the function, so ``dummy`` is only imported when this
    # entry point is actually called.
    from dummy import SimpleFeedForwardNet

    model = SimpleFeedForwardNet(**kwargs)
    if pretrained:
        hub_dir = os.path.dirname(os.path.abspath(__file__))
        weight_path = os.path.join(hub_dir, 'dummy-weights.bin')
        # map_location='cpu' lets a checkpoint saved from CUDA tensors be
        # deserialized on a machine without a GPU; load_state_dict then copies
        # the values into the model's parameters on whatever device they use.
        # NOTE: torch.load unpickles the file. It is read from this module's
        # own directory and treated as trusted; on torch >= 1.13 consider
        # weights_only=True to restrict unpickling to tensors/containers.
        state_dict = torch.load(weight_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def VanillaNet(pretrained=True, **kwargs):
    """Entry point for the VanillaNet model.

    Args:
        pretrained: If True, load the state dict stored in
            ``vanilla-weight.bin`` (in the same directory as this file)
            into the freshly constructed model.
        **kwargs: Forwarded unchanged to ``vanilla.SimpleFeedForwardNet``.

    Returns:
        The constructed ``vanilla.SimpleFeedForwardNet`` instance.
    """
    # Imported inside the function, so ``vanilla`` is only imported when this
    # entry point is actually called.
    from vanilla import SimpleFeedForwardNet

    model = SimpleFeedForwardNet(**kwargs)
    if pretrained:
        hub_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): singular "weight" here vs. "dummy-weights.bin" for
        # DummyNet; confirm this matches the file actually bundled.
        weight_path = os.path.join(hub_dir, 'vanilla-weight.bin')
        # map_location='cpu' lets a checkpoint saved from CUDA tensors be
        # deserialized on a machine without a GPU; load_state_dict then copies
        # the values into the model's parameters on whatever device they use.
        # NOTE: torch.load unpickles the file. It is read from this module's
        # own directory and treated as trusted; on torch >= 1.13 consider
        # weights_only=True to restrict unpickling to tensors/containers.
        state_dict = torch.load(weight_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model