"""Predict the class of a single image with a fine-tuned ViT classifier.

Run as a script: pass the image path as the first argument and, optionally,
``--model`` with the checkpoint to load.  The predicted label is printed as
``Predicted class: <label>``.
"""

import argparse

import torch
from PIL import Image
from transformers import ViTForImageClassification

try:
    # ``ViTFeatureExtractor`` is deprecated in favour of ``ViTImageProcessor``;
    # fall back to the old class on transformers releases that predate it.
    from transformers import ViTImageProcessor as ImageProcessor
except ImportError:
    from transformers import ViTFeatureExtractor as ImageProcessor

# Checkpoint passed to ``from_pretrained`` (a local directory or a Hub id).
MODEL_NAME = "saved_model"
# Placeholder kept from the original script; pass a real image path on the
# command line instead of editing this constant.
IMAGE_PATH = "/path/"


def load_classifier(model_name=MODEL_NAME):
    """Load the ViT classifier and its image processor from *model_name*.

    Returns:
        ``(model, processor)`` with the model already switched to eval mode.
    """
    model = ViTForImageClassification.from_pretrained(model_name)
    model.eval()  # inference mode: disables dropout and similar train-only layers
    processor = ImageProcessor.from_pretrained(model_name)
    return model, processor


def predict(image_path, model, processor):
    """Return the predicted label for the image stored at *image_path*.

    Args:
        image_path: Path to an image file readable by PIL.
        model: Classifier returned by :func:`load_classifier`.
        processor: Image processor returned by :func:`load_classifier`.

    Raises:
        OSError: If *image_path* cannot be opened as an image.
    """
    # Close the underlying file once decoded; ``convert`` returns a new image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is processed, so argmax over the class axis is one index.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def main(argv=None):
    """Parse command-line arguments, classify the image and print the label."""
    parser = argparse.ArgumentParser(
        description="Classify an image with a fine-tuned ViT model."
    )
    parser.add_argument(
        "image_path",
        nargs="?",
        default=IMAGE_PATH,
        help="path of the image to classify",
    )
    parser.add_argument(
        "--model",
        default=MODEL_NAME,
        help="checkpoint directory or Hub id to load (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    model, processor = load_classifier(args.model)
    label = predict(args.image_path, model, processor)
    print(f"Predicted class: {label}")


if __name__ == "__main__":
    main()