A4_4320232_Naser / hubconf.py
Naser9929's picture
Upload 6 files
8644d42 verified
# Packages that torch.hub must be able to import before it loads any entry
# point defined in this hubconf.
dependencies = ["torch"]
import torch
import os
def DummyNet(pretrained=True, **kwargs):
    """Entry point for the DummyNet model.

    Builds ``dummy.SimpleFeedForwardNet`` and, if requested, loads the
    ``dummy-weights.bin`` state dict stored in the same directory as this
    hubconf file.

    Args:
        pretrained: If true, load the bundled ``dummy-weights.bin`` weights
            into the freshly constructed model.
        **kwargs: Forwarded unchanged to the ``SimpleFeedForwardNet``
            constructor.

    Returns:
        The constructed ``SimpleFeedForwardNet`` instance.
    """
    # Imported inside the function so the module is resolved only when this
    # entry point is actually called (``dummy`` is expected to sit next to
    # this hubconf file).
    from dummy import SimpleFeedForwardNet

    model = SimpleFeedForwardNet(**kwargs)
    if pretrained:
        hub_dir = os.path.dirname(os.path.abspath(__file__))
        weight_path = os.path.join(hub_dir, 'dummy-weights.bin')
        # map_location='cpu' lets a checkpoint that was saved from GPU tensors
        # load on a CPU-only machine. load_state_dict copies the values into
        # the model's existing parameters, so the result is the same and the
        # caller can still move the model with .to(device) afterwards.
        # NOTE(review): torch.load unpickles the file; if the checkpoint is a
        # plain state dict, consider weights_only=True (torch >= 1.13).
        state_dict = torch.load(weight_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def VanillaNet(pretrained=True, **kwargs):
    """Entry point for the VanillaNet model.

    Builds ``vanilla.SimpleFeedForwardNet`` and, if requested, loads the
    ``vanilla-weight.bin`` state dict stored in the same directory as this
    hubconf file.

    Args:
        pretrained: If true, load the bundled ``vanilla-weight.bin`` weights
            into the freshly constructed model.
        **kwargs: Forwarded unchanged to the ``SimpleFeedForwardNet``
            constructor.

    Returns:
        The constructed ``SimpleFeedForwardNet`` instance.
    """
    # Imported inside the function so the module is resolved only when this
    # entry point is actually called (``vanilla`` is expected to sit next to
    # this hubconf file).
    from vanilla import SimpleFeedForwardNet

    model = SimpleFeedForwardNet(**kwargs)
    if pretrained:
        hub_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): singular "weight" here vs "dummy-weights.bin" used by
        # DummyNet -- confirm this matches the uploaded file name.
        weight_path = os.path.join(hub_dir, 'vanilla-weight.bin')
        # map_location='cpu' lets a checkpoint that was saved from GPU tensors
        # load on a CPU-only machine. load_state_dict copies the values into
        # the model's existing parameters, so the result is the same and the
        # caller can still move the model with .to(device) afterwards.
        # NOTE(review): torch.load unpickles the file; if the checkpoint is a
        # plain state dict, consider weights_only=True (torch >= 1.13).
        state_dict = torch.load(weight_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model