is-it-max / classification.py
paddeh's picture
visualise-segmentation (#1)
9073e25 verified
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from torchvision import transforms, models
from functions import import_class_labels
# Run inference on the GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device} for classification")

# Spatial size every image is resized to before it is fed to the model.
model_img_size = (224, 224)

# Labels loaded through the project helper from the current directory;
# classify() indexes this collection with the predicted class index.
class_labels = import_class_labels("./")

# Load the trained classifier and its matching image processor.
model_name = "paddeh/is-it-max"
print(f"Loading classifier model {model_name}")
model = AutoModelForImageClassification.from_pretrained(model_name)
model = model.to(device)
model = model.eval()  # inference mode: disables dropout / batch-norm updates
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Preprocessing pipeline: bicubic resize to the model input size, conversion
# to a tensor, then normalisation with the processor's own mean and std.
transform = transforms.Compose(
    [
        transforms.Resize(
            model_img_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=processor.image_mean,
            std=processor.image_std,
        ),
    ]
)
def classify(image):
    """Classify a single image and return its predicted class.

    The image is preprocessed with the module-level ``transform``, run
    through ``model`` with gradient tracking disabled, and the index of the
    highest logit is looked up in ``class_labels``.

    Args:
        image: The image to classify. Presumably a PIL image, since
            ``transform`` applies ``Resize`` followed by ``ToTensor`` --
            confirm against callers. Images whose ``mode`` is not ``"RGB"``
            (e.g. RGBA, greyscale, palette) are converted to RGB first.

    Returns:
        A ``(predicted_class_idx, predicted_label)`` tuple: the integer index
        of the highest-scoring class and the matching ``class_labels`` entry.
    """
    # Normalisation uses the processor's per-channel mean/std (presumably
    # three channels), so an image with a different channel count would fail
    # inside ``transform``. Convert such images to RGB up front; inputs
    # without a string ``mode`` attribute are passed through untouched.
    mode = getattr(image, "mode", None)
    if isinstance(mode, str) and mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension and move the tensor to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = class_labels[predicted_class_idx]
    return predicted_class_idx, predicted_label