Spaces:
Sleeping
Sleeping
| from transformers import AutoModelForImageClassification, AutoImageProcessor | |
| import torch | |
| from torchvision import transforms, models | |
| from functions import import_class_labels | |
# Choose the inference device once, at import time: CUDA when torch reports it
# is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device {device} for classification")

# Target (height, width) that every input image is resized to before inference.
model_img_size = (224, 224)

# Label lookup produced by the project helper `import_class_labels`, read
# relative to the current working directory ('./'). It is indexed by the
# predicted class id in classify().
class_labels = import_class_labels('./')
# Load the trained classifier and its image processor. `model_name` is the
# repo id (or local path) handed to `from_pretrained`.
model_name = "paddeh/is-it-max"
print(f"Loading classifier model {model_name}")

# Instantiate the weights, move them to the selected device and switch to
# evaluation mode (disables training-only behaviour such as dropout).
# `Module.to` and `Module.eval` both return the module itself, so `model`
# ends up as the same object the original chained call produced.
model = AutoModelForImageClassification.from_pretrained(model_name)
model = model.to(device).eval()

# In this file the processor supplies the normalization statistics
# (`image_mean` / `image_std`) used by the preprocessing pipeline;
# `use_fast=True` requests the fast processor implementation.
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
# Preprocessing applied to every image before it reaches the model.
transform = transforms.Compose(
    [
        # Resize directly to `model_img_size`. Because the size is given as a
        # (h, w) pair, the aspect ratio is not preserved. Resampling is bicubic.
        transforms.Resize(
            model_img_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        # PIL image / ndarray (H, W, C) -> float tensor (C, H, W); 8-bit input
        # is scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel normalization with the statistics shipped alongside the
        # model's image processor.
        transforms.Normalize(
            mean=processor.image_mean,
            std=processor.image_std,
        ),
    ]
)
def classify(image):
    """Run the loaded classifier on a single image.

    Args:
        image: An input accepted by the module-level ``transform`` (a PIL
            image or an ``H x W x C`` numpy array). PIL images in any mode
            other than ``"RGB"`` (e.g. ``"RGBA"``, ``"L"``, ``"P"``) are
            converted to RGB first.

    Returns:
        tuple: ``(predicted_class_idx, predicted_label)``. The index is the
        argmax of the model's logits; the label is
        ``class_labels[predicted_class_idx]``.
    """
    # An RGBA or grayscale PIL image would make ToTensor emit 4 or 1 channels,
    # and the in-place Normalize with the processor's per-channel mean/std
    # (presumably 3 values for an RGB model -- confirm) would then fail.
    # Only objects with a string `mode` (PIL images) are converted; numpy
    # arrays have no `mode` and pass through unchanged.
    mode = getattr(image, "mode", None)
    if isinstance(mode, str) and mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of 1 and move the tensor to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_tensor)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = class_labels[predicted_class_idx]
    return predicted_class_idx, predicted_label