File size: 530 Bytes
f113387
 
 
 
efc318c
f113387
efc318c
f113387
 
efc318c
f113387
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from src.models.aero import Aero
from src.models.seanet import Seanet
from yaml import safe_load

def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml",
              config_dir="conf/experiment"):
    """Build the generator model described by an experiment YAML file.

    The YAML file is read from ``config_dir/experiment_file`` and the
    top-level section whose key equals ``model_name`` is passed as keyword
    arguments to the matching model class.

    Args:
        model_name: Which model to build; either ``'aero'`` or ``'seanet'``.
        experiment_file: File name of the experiment YAML inside
            ``config_dir``.
        config_dir: Directory containing experiment YAML files. Defaults to
            the relative path ``conf/experiment`` that was previously
            hard-coded, so existing callers are unaffected.

    Returns:
        A dict ``{'generator': <model instance>}``.

    Raises:
        ValueError: If ``model_name`` is not a supported model. This check
            runs before any file is opened.
        KeyError: If the YAML file has no section named ``model_name``.
    """
    import os

    # Resolve the model class first so an unknown name fails fast with a
    # clear message instead of leaving `generator` unbound later on.
    if model_name == 'aero':
        model_cls = Aero
    elif model_name == 'seanet':
        model_cls = Seanet
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'aero' or 'seanet'"
        )

    config_path = os.path.join(config_dir, experiment_file)
    with open(config_path, encoding="utf-8") as f:
        # safe_load: never construct arbitrary Python objects from config.
        config = safe_load(f)

    # Each model's constructor kwargs live under a section named after it.
    generator = model_cls(**config[model_name])

    models = {'generator': generator}

    return models