| def get_model(model_config, task=''): | |
| if '/vit/' in model_config.yaml_path: | |
| from .vit import load_model as load_vit_model | |
| model = load_vit_model(model_config) | |
| print('Loaded ViT model') | |
| elif '/vit_irpe/' in model_config.yaml_path: | |
| from .vit_irpe import load_model as load_vit_irpe_model | |
| model = load_vit_irpe_model(model_config) | |
| print('Loaded ViT model with iRPE') | |
| elif '/vit_kprpe/' in model_config.yaml_path: | |
| from .vit_kprpe import load_model as load_vit_kprpe_model | |
| model = load_vit_kprpe_model(model_config) | |
| print('Loaded ViT model with KPRPE') | |
| elif '/iresnet/' in model_config.yaml_path: | |
| from .iresnet import load_model as load_iresnet_model | |
| model = load_iresnet_model(model_config) | |
| print('Loaded iResNet model') | |
| elif '/iresnet_insightface/' in model_config.yaml_path: | |
| from .iresnet_insightface import load_model as load_iresnet_insightface_model | |
| model = load_iresnet_insightface_model(model_config) | |
| print('Loaded iResNet model') | |
| elif '/part_fvit/' in model_config.yaml_path: | |
| from .part_fvit import load_model as load_part_fvit_model | |
| model = load_part_fvit_model(model_config) | |
| print('Loaded PartFVIT model') | |
| elif '/swin/' in model_config.yaml_path: | |
| from .swin import load_model as load_swin_model | |
| model = load_swin_model(model_config) | |
| print('Loaded Swin model') | |
| elif '/swin_kprpe/' in model_config.yaml_path: | |
| from .swin_kprpe import load_model as load_swin_kprpe_model | |
| model = load_swin_kprpe_model(model_config) | |
| print('Loaded Swin model with KPRPE') | |
| else: | |
| raise NotImplementedError(f"Model {model_config.yaml_path} not implemented") | |
| if model_config.start_from: | |
| model.load_state_dict_from_path(model_config.start_from) | |
| if model_config.freeze: | |
| for param in model.parameters(): | |
| param.requires_grad = False | |
| return model | |