import warnings

import timm
import torch
|
|
# Default compute device: prefer CUDA when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
def load_model(model_path, model_type="vit", device=None):
    """Build a ViT binary-classification model and load checkpoint weights.

    Args:
        model_path: Path to a checkpoint file readable by ``torch.load``.
        model_type: ``"clip"`` selects the CLIP-pretrained ViT-B/16 architecture;
            any other value selects the plain ViT-B/16 architecture.
        device: Target device for the model; defaults to the module-level
            ``DEVICE`` ("cuda" when available, else "cpu").

    Returns:
        The model moved to ``device`` and set to eval mode.
    """
    if device is None:
        device = DEVICE

    if model_type == "clip":
        model = timm.create_model("vit_base_patch16_clip_224.openai", pretrained=False, num_classes=1)
    else:
        model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=1)

    # Load on CPU first so checkpoints saved on GPU machines deserialize anywhere.
    checkpoint = torch.load(model_path, map_location="cpu")

    # Checkpoints may be a raw state dict, or a training dict wrapping one
    # under "model_state_dict" or "state_dict".
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("model_state_dict", checkpoint.get("state_dict", checkpoint))
    else:
        state_dict = checkpoint

    # strict=False tolerates key mismatches, but surface what was skipped
    # instead of silently returning a partially-initialized model.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        warnings.warn(f"Missing keys when loading {model_path}: {result.missing_keys}")
    if result.unexpected_keys:
        warnings.warn(f"Unexpected keys when loading {model_path}: {result.unexpected_keys}")

    model.to(device)
    model.eval()

    return model