| from src.models.aero import Aero | |
| from src.models.seanet import Seanet | |
| from yaml import safe_load | |
def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
    """Build the generator network described by an experiment YAML file.

    The file ``conf/experiment/<experiment_file>`` is parsed with
    ``yaml.safe_load``; the mapping stored under the top-level key equal to
    ``model_name`` is passed as keyword arguments to the model constructor.

    Args:
        model_name: Which generator to build: ``'aero'`` or ``'seanet'``.
        experiment_file: File name of the YAML config inside
            ``conf/experiment/`` (resolved relative to the current working
            directory).

    Returns:
        A dict with the single key ``'generator'`` holding the constructed
        model instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    supported = ("aero", "seanet")
    # Validate first: previously an unknown name fell through both branches
    # and crashed later with an UnboundLocalError on ``generator``.
    if model_name not in supported:
        raise ValueError(
            "Unknown model_name {!r}; expected one of {}".format(
                model_name, ", ".join(repr(name) for name in supported)
            )
        )

    # YAML files are UTF-8 by specification; do not rely on the locale default.
    with open("conf/experiment/" + experiment_file, encoding="utf-8") as f:
        config = safe_load(f)

    # The config section name matches the model name in both cases.
    model_cls = Aero if model_name == "aero" else Seanet
    generator = model_cls(**config[model_name])

    return {"generator": generator}