# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import importlib
def get_all_models(path='models'):
    """Return the module names of the model files found in *path*.

    A directory entry counts as a model module when it ends in ``.py`` and its
    name does not contain ``'__'`` (this skips ``__init__.py`` and
    ``__pycache__``). The ``.py`` suffix is stripped. The result is sorted so it
    does not depend on the platform's directory-listing order.

    :param path: directory to scan; defaults to ``'models'``, which is resolved
        relative to the current working directory.
    :return: sorted list of module names, e.g. ``['er', 'sgd']``.
    """
    # A suffix test (not a substring test) keeps non-Python entries such as
    # 'happy.txt', stale '*.pyc' files or 'pyutils/' out of the registry.
    return sorted(entry[:-len('.py')] for entry in os.listdir(path)
                  if entry.endswith('.py') and '__' not in entry)
# Registry mapping each module name returned by get_all_models() to the
# attribute of 'models.<name>' whose name equals the module name with
# underscores removed, compared case-insensitively (presumably the model class,
# since get_model() instantiates it).
names = {}
for model in get_all_models():
    mod = importlib.import_module('models.' + model)
    # dir() returns a sorted attribute list, so the lookup table is built
    # deterministically.
    attrs_by_lower = {attr.lower(): attr for attr in dir(mod)}
    expected = model.replace('_', '').lower()
    if expected not in attrs_by_lower:
        # Same exception type as the original bare lookup, but explain what
        # is missing instead of surfacing only the bare key.
        raise KeyError("models/{}.py does not define an attribute matching "
                       "'{}' (case-insensitive, underscores ignored)"
                       .format(model, expected))
    class_name = attrs_by_lower[expected]
    names[model] = getattr(mod, class_name)
def get_model(args, backbone, loss, transform):
    """Build the model registered in ``names`` under ``args.model``.

    :param args: object with a ``model`` attribute naming a registered model
        module; it is also forwarded to the model constructor.
    :param backbone: forwarded to the model constructor.
    :param loss: forwarded to the model constructor.
    :param transform: forwarded to the model constructor.
    :return: ``names[args.model](backbone, loss, args, transform)``.
    :raises KeyError: if ``args.model`` is not a key of ``names``.
    """
    model_class = names[args.model]
    # Constructor argument order is (backbone, loss, args, transform).
    return model_class(backbone, loss, args, transform)
|