| from .config import set_layer_config | |
| from .helpers import load_checkpoint | |
| from .gen_efficientnet import * | |
| from .mobilenetv3 import * | |
def create_model(
        model_name='mnasnet_100',
        pretrained=None,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='',
        **kwargs):
    """Instantiate a model by looking up its entrypoint in this module's namespace.

    ``model_name`` is resolved against this module's ``globals()``. These are
    populated, among other things, by the star-imports of ``gen_efficientnet``
    and ``mobilenetv3``. The matching callable is invoked with the model keyword
    arguments assembled below.

    Args:
        model_name: Name of a public, callable entrypoint in this module's namespace.
        pretrained: Forwarded to the entrypoint as the ``pretrained`` keyword.
        num_classes: Forwarded to the entrypoint as the ``num_classes`` keyword.
        in_chans: Forwarded to the entrypoint as the ``in_chans`` keyword.
        checkpoint_path: If non-empty and ``pretrained`` is falsy, the result is
            passed to ``load_checkpoint`` with this path after construction.
            It is silently ignored when ``pretrained`` is truthy.
        **kwargs: Additional keyword arguments forwarded to the entrypoint.

    Returns:
        The object returned by the entrypoint.

    Raises:
        RuntimeError: If ``model_name`` does not name a public callable in this
            module's namespace.
    """
    model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)

    # Only accept public callables. A bare ``name in globals()`` check would also
    # match module dunders such as '__name__' or '__doc__'. Those would then fail
    # with an opaque TypeError instead of the documented RuntimeError.
    create_fn = globals().get(model_name)
    if create_fn is None or model_name.startswith('_') or not callable(create_fn):
        raise RuntimeError('Unknown model (%s)' % model_name)
    model = create_fn(**model_kwargs)

    # An explicit checkpoint is only loaded when pretrained weights were not requested.
    if checkpoint_path and not pretrained:
        load_checkpoint(model, checkpoint_path)

    return model