Vit-Flower / example_inference.py
Acras's picture
Upload 11 files
439df6b verified
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
def _label_for(id2label, class_idx):
    """Return the label name for ``class_idx`` from an ``id2label`` mapping.

    ``transformers`` configs convert ``id2label`` keys to ``int`` on load, so
    the integer key is tried first; a ``str`` key (as stored in JSON) is
    accepted as a fallback, and ``LABEL_<idx>`` is returned if neither exists.
    """
    if class_idx in id2label:
        return id2label[class_idx]
    return id2label.get(str(class_idx), f"LABEL_{class_idx}")


def main(top_k=5):
    """Classify a bundled test image and print the most likely labels.

    The image processor and model are loaded from the directory containing
    this script, and the image is read from ``assets/164.jpg`` in that
    directory. If the image is missing, a message is printed and the function
    returns without running inference.

    Args:
        top_k: Number of predictions to print (default 5). Clamped to the
            range ``[0, num_classes]`` so that small label sets do not make
            ``torch.topk`` fail.
    """
    model_dir = os.path.dirname(os.path.abspath(__file__))
    print("Loading model and processor...")
    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = AutoModelForImageClassification.from_pretrained(model_dir)
    model.eval()  # evaluation mode: disables dropout-style training behavior

    test_image_path = os.path.join(model_dir, "assets", "164.jpg")
    if not os.path.exists(test_image_path):
        print(f"Test image not found: {test_image_path}")
        print("Please add a test image to assets/ folder")
        return

    print(f"Loading image: {test_image_path}")
    # Close the underlying file handle; convert() returns a new, fully
    # loaded image that remains valid after the source image is closed.
    with Image.open(test_image_path) as img:
        image = img.convert("RGB")

    print("Processing image...")
    inputs = processor(images=image, return_tensors="pt")

    print("Running inference...")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = max(0, min(top_k, probs.shape[-1]))
    top = torch.topk(probs[0], k=k)

    print(f"\nTop {k} Predictions:")
    print("-" * 60)
    id2label = model.config.id2label
    for rank, (prob, idx) in enumerate(zip(top.values, top.indices), start=1):
        class_idx = idx.item()
        label = _label_for(id2label, class_idx)
        confidence = prob.item() * 100
        print(f"{rank}. {label} | Confidence: {confidence:.2f}%")
if __name__ == "__main__":
    main()