sereich's picture
Add phone model (beta), allow models to use different architectures
efc318c
raw
history blame contribute delete
530 Bytes
from src.models.aero import Aero
from src.models.seanet import Seanet
from yaml import safe_load
def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
    """Build the generator network selected by ``model_name``.

    The experiment YAML file is read from ``conf/experiment/`` and the
    mapping stored under the key equal to ``model_name`` is passed as
    keyword arguments to the matching model class.

    Args:
        model_name: Architecture to build; ``'aero'`` or ``'seanet'``.
        experiment_file: File name of the experiment config, relative to
            ``conf/experiment/``.

    Returns:
        A dict with the single key ``'generator'`` holding the constructed
        model.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    # Validate before touching the filesystem: previously an unsupported
    # name fell through both branches and crashed later with an
    # UnboundLocalError on ``generator``.
    if model_name not in ("aero", "seanet"):
        raise ValueError(
            "Unknown model_name {!r}; expected 'aero' or 'seanet'".format(
                model_name
            )
        )
    model_cls = Aero if model_name == "aero" else Seanet

    with open("conf/experiment/" + experiment_file) as f:
        # The config section to read is keyed by the model's own name.
        generator = model_cls(**safe_load(f)[model_name])

    return {"generator": generator}