# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import importlib


def get_all_models():
    """Return the names of the model modules found in this package's directory.

    Every ``*.py`` file next to this file is treated as a model module,
    except files whose name contains ``'__'`` (e.g. ``__init__.py``).

    The directory is resolved from ``__file__``, not from the current
    working directory. The result is therefore the same wherever the
    program is launched from.

    :return: sorted list of module names (file names without the ``.py``
        extension)
    """
    # The previous os.listdir('models') only worked when the process was
    # started from the directory containing 'models'.
    models_dir = os.path.dirname(os.path.abspath(__file__))
    return sorted(
        os.path.splitext(fname)[0]
        for fname in os.listdir(models_dir)
        # endswith('.py') instead of a "'py' in fname" substring test, which
        # also accepted any name merely containing "py" (e.g. '.pyc' files).
        if '__' not in fname and fname.endswith('.py')
    )


def _build_registry():
    """Import every model module and map its name to its model class.

    For a module ``foo_bar`` the model class is the module attribute whose
    lower-cased name equals the module name with underscores removed
    (``'foobar'``), e.g. ``FooBar``.

    :return: dict mapping module name -> model class
    :raises ImportError: if a model module has no attribute matching its name
    """
    registry = {}
    for model in get_all_models():
        # NOTE(review): the import path assumes this package is importable as
        # the top-level name 'models' — confirm if it is ever nested.
        mod = importlib.import_module('models.' + model)
        # Case-insensitive lookup table: lower-cased attribute name -> real name.
        lowered = {attr.lower(): attr for attr in mod.__dir__()}
        wanted = model.replace('_', '')
        if wanted not in lowered:
            raise ImportError(
                "module 'models.%s' defines no attribute matching '%s' "
                "(case-insensitive)" % (model, wanted))
        registry[model] = getattr(mod, lowered[wanted])
    return registry


# Registry of available models (module name -> class), built eagerly when
# this package is imported.
names = _build_registry()


def get_model(args, backbone, loss, transform):
    """Instantiate the model selected by ``args.model``.

    :param args: object whose ``model`` attribute is a key of ``names``; it
        is also forwarded to the model constructor
    :param backbone: forwarded to the model constructor
    :param loss: forwarded to the model constructor
    :param transform: forwarded to the model constructor
    :return: ``names[args.model](backbone, loss, args, transform)``
    :raises KeyError: if ``args.model`` is not a registered model name
    """
    return names[args.model](backbone, loss, args, transform)