File size: 710 Bytes
5ad0794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import warnings

import timm
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Default device string for models built in this module: the GPU whenever
# PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def load_model(model_path, model_type="vit", *, strict=False, device=None):
    """Build a single-output timm ViT and load its weights from a checkpoint.

    Args:
        model_path: File passed to ``torch.load``. Accepted contents: a raw
            state dict, a dict wrapping one under ``"model_state_dict"`` or
            ``"state_dict"``, or a pickled ``torch.nn.Module``.
        model_type: ``"clip"`` builds ``vit_base_patch16_clip_224.openai``;
            every other value (including the default ``"vit"``) builds
            ``vit_base_patch16_224``.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False`` to
            keep the original lenient loading. In that mode, missing or
            unexpected keys are reported with a ``RuntimeWarning`` instead
            of being silently ignored.
        device: Target device for the model; ``None`` uses the module-level
            ``DEVICE``.

    Returns:
        The model, moved to the target device and put in eval mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` (e.g. a key
            mismatch when ``strict`` is true, or a tensor shape mismatch).
    """
    arch = (
        "vit_base_patch16_clip_224.openai"
        if model_type == "clip"
        else "vit_base_patch16_224"
    )
    # Single-logit head (num_classes=1), randomly initialised until the
    # checkpoint weights are loaded below.
    model = timm.create_model(arch, pretrained=False, num_classes=1)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints; consider
    # weights_only=True where the installed PyTorch supports it.
    checkpoint = torch.load(model_path, map_location="cpu")

    if isinstance(checkpoint, torch.nn.Module):
        # A whole model object was saved instead of just its weights.
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        # Unwrap a nested state dict if present; otherwise the dict itself
        # is treated as the state dict.
        state_dict = checkpoint.get(
            "model_state_dict", checkpoint.get("state_dict", checkpoint)
        )
    else:
        state_dict = checkpoint

    # Weights saved from a DataParallel/DistributedDataParallel wrapper carry
    # a "module." prefix on every key. With strict=False those keys would
    # match nothing and the model would silently keep its random init.
    if (
        isinstance(state_dict, dict)
        and state_dict
        and all(
            isinstance(key, str) and key.startswith("module.")
            for key in state_dict
        )
    ):
        consume_prefix_in_state_dict_if_present(state_dict, "module.")

    incompatible = model.load_state_dict(state_dict, strict=strict)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        # Only reachable when strict is False (strict=True raises instead).
        warnings.warn(
            f"load_model({model_path!r}): "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"{incompatible.missing_keys[:5]}, "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"{incompatible.unexpected_keys[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(DEVICE if device is None else device)
    model.eval()

    return model