File size: 522 Bytes
0a6c6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import warnings

import torch
import timm
from safetensors.torch import load_file

# Default compute device: CUDA whenever PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Parameter names that common timm architectures use for the final
# classifier weight, tried in order.
# NOTE(review): these follow timm's model-definition conventions, which this
# file does not control -- extend the tuple if a checkpoint uses another name.
_CLASSIFIER_WEIGHT_KEYS = (
    "head.weight",
    "head.fc.weight",
    "fc.weight",
    "classifier.weight",
)


def _infer_num_classes(state_dict):
    """Return the output size of the classifier stored in *state_dict*.

    Uses the first key from ``_CLASSIFIER_WEIGHT_KEYS`` present in the
    checkpoint and returns the leading dimension of that weight (a Linear
    weight is laid out ``(out_features, in_features)``).

    Raises:
        ValueError: if none of the known classifier keys is present.
    """
    for key in _CLASSIFIER_WEIGHT_KEYS:
        weight = state_dict.get(key)
        if weight is not None:
            return weight.shape[0]
    raise ValueError(
        "could not find a classifier weight in the checkpoint "
        f"(looked for {', '.join(_CLASSIFIER_WEIGHT_KEYS)}); "
        "pass num_classes explicitly (use 0 for a headless checkpoint)"
    )


def load_model(path, arch="vit_base_patch16_224", *, num_classes=None,
               strict=False, device=None):
    """Build a timm model and load its weights from a safetensors file.

    Args:
        path: Path to a ``.safetensors`` file holding the model state dict.
        arch: timm architecture name passed to ``timm.create_model``.
        num_classes: Classifier output size. When ``None`` (default) it is
            read from the checkpoint's classifier weight.
        strict: Forwarded to ``load_state_dict``. With the default ``False``
            any missing/unexpected keys are reported via ``warnings.warn``
            instead of being silently dropped.
        device: Target device; defaults to the module-level ``DEVICE``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: if ``num_classes`` is ``None`` and no known classifier
            weight is present in the checkpoint.
        RuntimeError: from ``load_state_dict`` on tensor shape mismatches,
            or on any key mismatch when ``strict=True``.
    """
    state_dict = load_file(path)

    # safetensors' load_file yields tensors in sorted-name order, not
    # definition order, so "the last key" is not the classifier (for timm
    # ViTs it is ``pos_embed``). Look the classifier up by name instead.
    if num_classes is None:
        num_classes = _infer_num_classes(state_dict)

    model = timm.create_model(arch, pretrained=False, num_classes=num_classes)
    incompatible = model.load_state_dict(state_dict, strict=strict)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        # Non-strict loading would otherwise hide a checkpoint/arch mismatch
        # and leave the affected layers randomly initialised.
        warnings.warn(
            f"checkpoint {path!r} does not fully match {arch!r}: "
            f"missing={incompatible.missing_keys}, "
            f"unexpected={incompatible.unexpected_keys}",
            stacklevel=2,
        )

    model.to(DEVICE if device is None else device)
    model.eval()

    return model