"""Image classification helper built on a pretrained Hugging Face model.

Importing this module selects a torch device, loads the classifier and its
image processor, and builds the preprocessing pipeline used by ``classify``.
"""

from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from torchvision import transforms, models

from functions import import_class_labels

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for classification")

# Spatial size every image is resized to before being fed to the model.
model_img_size = (224, 224)

# Label lookup indexed by predicted class index in ``classify``.
# NOTE(review): presumably loaded from a labels file under './' -- confirm
# against ``functions.import_class_labels``.
class_labels = import_class_labels('./')

# Load trained model and feature extractor
model_name = "paddeh/is-it-max"
print(f"Loading classifier model {model_name}")
model = AutoModelForImageClassification.from_pretrained(model_name) \
    .to(device) \
    .eval()
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Define image transformations: resize, convert to a tensor, then normalize
# with the per-channel statistics reported by the model's image processor.
transform = transforms.Compose([
    transforms.Resize(
        model_img_size,
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])


def classify(image):
    """Classify a single image with the loaded model.

    Args:
        image: The image to classify. Expected to be a PIL image; images
            exposing a ``mode`` other than ``"RGB"`` (e.g. grayscale, RGBA
            or palette images) are converted to RGB first.

    Returns:
        A ``(predicted_class_idx, predicted_label)`` tuple, where the index
        is the argmax over the model's logits and the label is the matching
        entry of ``class_labels``.
    """
    # Normalization uses the processor's per-channel mean/std, so an image
    # with a different channel count would make ``Normalize`` raise. Objects
    # without a ``mode`` attribute are passed through untouched.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension and move the input to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = class_labels[predicted_class_idx]
    return predicted_class_idx, predicted_label