File size: 770 Bytes
cbd7eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch

def load_model(model_name="microsoft/beit-base-patch16-224"):
    """Load the image processor and classification model for one checkpoint.

    Args:
        model_name: Hugging Face Hub id or local directory passed to both
            ``from_pretrained`` calls, so processor and model always match.
            Defaults to the BEiT base checkpoint this module was written for.

    Returns:
        A ``(processor, model)`` tuple suitable for ``classify_food``.
    """
    # NOTE(review): the default checkpoint appears (per its model card) to be
    # a general ImageNet-1k classifier rather than a food-specific model;
    # pass a food checkpoint here if food labels are required — confirm.
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    return processor, model

def classify_food(image, processor, model):
    """Classify a single image and return its top-1 label and probability.

    Args:
        image: A PIL image, or anything ``numpy.asarray`` turns into an array
            accepted by ``PIL.Image.fromarray`` (e.g. an ``(H, W, 3)`` or
            ``(H, W)`` uint8 array).
        processor: Image processor, e.g. the one returned by ``load_model``.
        model: Image-classification model whose ``config.id2label`` maps
            class indices to label strings.

    Returns:
        ``(label, confidence)``: the predicted label string and its softmax
        probability as a Python float.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
    """
    from PIL import Image
    import numpy as np

    if image is None:
        raise ValueError("classify_food() requires an image, got None")

    # Use PIL images directly: round-tripping them through numpy would turn
    # palette ("P") images into their raw palette indices.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    # Force three channels so grayscale, RGBA and palette inputs reach the
    # processor in the same RGB layout as ordinary photos.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Keep inputs on the model's device so a model moved to GPU still works.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits[0]  # single image -> first batch row
    predicted_class_idx = int(logits.argmax(-1))
    label = model.config.id2label[predicted_class_idx]
    confidence = logits.softmax(dim=-1)[predicted_class_idx].item()
    return label, confidence