File size: 1,243 Bytes
9073e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from torchvision import transforms, models

from functions import import_class_labels

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device} for classification")

# Target (height, width) passed to the resize step below.
model_img_size = (224, 224)
# Label lookup from the project helper; classify() indexes it by the
# predicted class index.
class_labels = import_class_labels("./")

# Load the trained classifier and its matching image processor.
model_name = "paddeh/is-it-max"
print(f"Loading classifier model {model_name}")
model = AutoModelForImageClassification.from_pretrained(model_name)
model = model.to(device)
model = model.eval()  # evaluation mode (e.g. disables dropout)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Preprocessing pipeline: bicubic resize -> tensor -> normalize with the
# mean/std reported by the image processor.
transform = transforms.Compose(
    [
        transforms.Resize(
            model_img_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=processor.image_mean,
            std=processor.image_std,
        ),
    ]
)


def classify(image):
    """Classify a single image with the module-level model.

    Args:
        image: The image to classify. Presumably a PIL image (the transform
            pipeline resizes and then calls ``ToTensor``) -- confirm against
            callers. PIL images whose mode is not ``"RGB"`` (e.g. RGBA,
            grayscale, palette) are converted to RGB before preprocessing.

    Returns:
        A ``(predicted_class_idx, predicted_label)`` tuple: the integer index
        of the largest logit and the matching entry of ``class_labels``.
    """
    # Normalization applies per-channel mean/std, so an image with a
    # different channel count (alpha channel, single-channel grayscale)
    # would fail inside the transform. Only objects exposing a string
    # ``mode`` (PIL images) are converted; anything else passes through.
    # NOTE(review): assumes the model expects 3-channel RGB input -- confirm
    # against the model/processor configuration.
    mode = getattr(image, "mode", None)
    if isinstance(mode, str) and mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension and move the input to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = class_labels[predicted_class_idx]

    return predicted_class_idx, predicted_label