| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from common.utils import LOGGER |
| from torchvision.models import MobileNetV2 |
|
|
| from common.registries.model_registry import MODEL_WRAPPER_REGISTRY |
| from image_classification.pt.src.models.edgevit.edgevit import (edgevit_s, edgevit_xs, |
| edgevit_xxs) |
| from image_classification.pt.src.models.fasternet.fasternet import (fasternet_l, |
| fasternet_m, |
| fasternet_s, |
| fasternet_t0, |
| fasternet_t1, |
| fasternet_t2) |
| from image_classification.pt.src.models.stresnet import STResNetMicro, STResNetMilli, STResNetNano ,STResNetPico, STResNetTiny |
| from image_classification.pt.wrappers.models.utils import load_checkpoint_ic |
| from image_classification.pt.wrappers.models.checkpoints import MODEL_CHECKPOINTS |
|
|
# Exported names are not listed statically: the registration loop at the end
# of this module appends one public name per MODEL_FNS entry.
__all__ = list()


# Registration table: public wrapper name -> (model constructor, extra keyword
# arguments passed to the constructor in addition to ``num_classes``).
MODEL_FNS = dict(
    edgevit_s_pt=(edgevit_s, dict()),
    edgevit_xs_pt=(edgevit_xs, dict()),
    edgevit_xxs_pt=(edgevit_xxs, dict()),
    fasternet_t0_pt=(fasternet_t0, dict()),
    fasternet_t1_pt=(fasternet_t1, dict()),
    fasternet_t2_pt=(fasternet_t2, dict()),
    fasternet_s_pt=(fasternet_s, dict()),
    fasternet_m_pt=(fasternet_m, dict()),
    fasternet_l_pt=(fasternet_l, dict()),
    mobilenetv2_w035_pt=(MobileNetV2, dict(width_mult=0.35)),
    st_resnetmicro_actrelu_pt=(STResNetMicro, dict()),
    st_resnetmilli_actrelu_pt=(STResNetMilli, dict()),
    st_resnetnano_actrelu_pt=(STResNetNano, dict()),
    st_resnetpico_actrelu_pt=(STResNetPico, dict()),
    st_resnettiny_actrelu_pt=(STResNetTiny, dict()),
)
|
|
def register_model_wrapper(model_fn, model_name, **model_init_kwargs):
    """Build a config-driven factory for one model and register it.

    Args:
        model_fn: Model constructor/factory. It is called with
            ``num_classes=cfg.dataset.num_classes`` plus ``model_init_kwargs``.
        model_name: Name passed to ``MODEL_WRAPPER_REGISTRY.register`` and
            used as the factory's ``__name__`` / ``__qualname__``.
        **model_init_kwargs: Extra keyword arguments forwarded to ``model_fn``
            on every call (e.g. ``width_mult`` for MobileNetV2).

    Returns:
        The object returned by ``MODEL_WRAPPER_REGISTRY.register(...)`` when
        applied to the factory, renamed to ``model_name``.
    """

    def get_model(cfg):
        """Instantiate the model described by ``cfg``.

        Reads ``cfg.dataset.num_classes`` for the classifier size and, when
        ``cfg.model.pretrained`` is truthy, loads weights through
        ``load_checkpoint_ic``.
        """
        model = model_fn(num_classes=cfg.dataset.num_classes, **model_init_kwargs)
        if cfg.model.pretrained:
            model = load_checkpoint_ic(model, cfg)
        return model

    # A checkpoint counts as available when any MODEL_CHECKPOINTS key starts
    # with this model's name (prefix match, not exact match).
    has_checkpoint = any(key.startswith(model_name) for key in MODEL_CHECKPOINTS)

    get_model = MODEL_WRAPPER_REGISTRY.register(
        framework='torch',
        model_name=model_name,
        use_case="image_classification",
        has_checkpoint=has_checkpoint,
    )(get_model)

    # Rename the factory after the model. Set __qualname__ as well as
    # __name__. Otherwise repr() and pickle still see
    # "register_model_wrapper.<locals>.get_model", and pickling by reference
    # fails even when the factory is bound as a module attribute named
    # ``model_name``.
    get_model.__name__ = model_name
    get_model.__qualname__ = model_name
    return get_model
|
|
|
|
# Create one registered factory per MODEL_FNS entry. Bind it as a module
# attribute of the same name and export that name through __all__.
for _model_name in MODEL_FNS:
    _model_fn, _model_kwargs = MODEL_FNS[_model_name]
    globals()[_model_name] = register_model_wrapper(
        _model_fn,
        _model_name,
        **_model_kwargs,
    )
    __all__.append(_model_name)