File size: 1,029 Bytes
768ca02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

import json, torch, timm
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# timm architecture identifier passed to timm.create_model() in load_model().
MODEL_NAME = "vit_base_patch16_224"
# Side length (pixels) of the center crop applied to the input in predict().
IMG_SIZE = 224
# Per-channel mean/std handed to transforms.Normalize() in predict().
# NOTE(review): these look like the usual ImageNet statistics; confirm they
# match the normalization used when the checkpoint was trained.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def load_model(repo_dir="."):
    """Build the classifier and load its weights from ``repo_dir``.

    Reads ``config.json`` from ``repo_dir``, creates a ``MODEL_NAME`` timm
    model with ``cfg["num_labels"]`` output classes (without downloading
    pretrained weights), loads ``model.safetensors`` into it and switches
    the model to eval mode.

    Args:
        repo_dir: Directory containing ``config.json`` and
            ``model.safetensors``. Defaults to the current directory.

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode and the parsed
        config dictionary.

    Raises:
        FileNotFoundError: If ``config.json`` cannot be found.
        KeyError: If the config has no ``"num_labels"`` entry.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            model (``load_state_dict`` is called in its default strict mode).
    """
    # JSON is UTF-8 by specification; being explicit keeps non-ASCII label
    # names intact on platforms whose default locale encoding differs.
    with open(f"{repo_dir}/config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    # pretrained=False: the weights come from the local safetensors file.
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=cfg["num_labels"])
    state = load_file(f"{repo_dir}/model.safetensors")
    model.load_state_dict(state)
    model.eval()
    return model, cfg

def predict(image_path, repo_dir="."):
    """Classify a single image and return its label.

    Loads the model via ``load_model(repo_dir)``, resizes and center-crops
    the image to ``IMG_SIZE``, normalizes it with ``MEAN``/``STD``, runs one
    forward pass without gradient tracking and maps the arg-max class index
    through ``cfg["id2label"]``.

    Note that the model is loaded from disk on every call.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        repo_dir: Directory holding ``config.json`` and
            ``model.safetensors``. Defaults to the current directory.

    Returns:
        The ``cfg["id2label"]`` entry for the predicted class index.

    Raises:
        KeyError: If ``cfg`` has no ``"id2label"`` mapping or no entry for
            the predicted index.
    """
    model, cfg = load_model(repo_dir)

    tfm = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    # Image.open is lazy and keeps the underlying file open; the context
    # manager guarantees the handle is released once the tensor is built.
    with Image.open(image_path) as img:
        # unsqueeze(0) adds the leading batch dimension.
        x = tfm(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(x)
        pred = logits.argmax(-1).item()

    # The config is parsed from JSON, so the id2label keys are strings.
    return cfg["id2label"][str(pred)]