# EGAI-vision-encoder / v1 / inference.py
# ANISH-j's picture
# Upload 9 files
# 63843fb verified
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification
# Directory holding the fine-tuned weights and the preprocessing config.
# Adjust if needed: this assumes the final model was saved in the training
# output directory. Use e.g. "./siglip2_finetuned/checkpoint-1284" to load a
# specific checkpoint instead.
model_path = "./siglip2_finetuned"

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the image processor and the classifier from the same directory, then
# move the classifier to the chosen device and switch it to evaluation mode.
processor = AutoImageProcessor.from_pretrained(model_path)
model = SiglipForImageClassification.from_pretrained(model_path)
model.to(device).eval()

# Label mapping taken from the model config; indexed with the predicted
# class id when a prediction is displayed.
id2label = model.config.id2label
# Tkinter GUI
class ImageClassifierApp:
    """Tkinter window that classifies a user-selected image.

    The window contains an instruction label, a "Load Image" button, a square
    preview canvas and a label showing the most recent prediction. Inference
    uses the module-level ``processor``, ``model``, ``device`` and
    ``id2label`` objects.
    """

    # Side length, in pixels, of the square preview canvas.
    PREVIEW_SIZE = 400

    def __init__(self, root):
        """Create and lay out the widgets inside *root*.

        Args:
            root: The Tk window that hosts the widgets.
        """
        self.root = root
        self.root.title("SigLIP2 Gardner Grading Classifier")
        self.root.geometry("600x600")

        # Label for instructions
        self.instruction_label = tk.Label(root, text="Select an image to classify")
        self.instruction_label.pack(pady=10)

        # Button to load image
        self.load_button = tk.Button(root, text="Load Image", command=self.load_image)
        self.load_button.pack(pady=10)

        # Canvas to display image
        self.image_canvas = tk.Canvas(
            root, width=self.PREVIEW_SIZE, height=self.PREVIEW_SIZE, bg="white"
        )
        self.image_canvas.pack(pady=10)

        # Label to display prediction
        self.prediction_label = tk.Label(root, text="", font=("Arial", 14))
        self.prediction_label.pack(pady=10)

        # PhotoImage currently drawn on the canvas (None until an image is loaded).
        self.photo_img = None

    def load_image(self):
        """Ask for an image file, preview it and display its predicted grade.

        Returns without doing anything when the file dialog is cancelled. Any
        failure while opening, previewing or classifying the image is reported
        in an error dialog instead of propagating to the Tk event loop.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.png *.jpg *.jpeg")]
        )
        if not file_path:
            return

        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory image, so it stays usable afterwards.
            with Image.open(file_path) as img:
                # Normalise palette, grayscale and alpha images to three
                # channels before they reach the image processor.
                rgb_img = img.convert("RGB")

            self._show_preview(rgb_img)
            # Clear the previous result so a failed inference cannot leave a
            # stale prediction next to the newly displayed image.
            self.prediction_label.config(text="")

            predicted_label = self._predict_label(rgb_img)
            self.prediction_label.config(text=f"Predicted Grade: {predicted_label}")
        except Exception as e:
            # UI boundary: surface every failure to the user.
            messagebox.showerror("Error", f"Failed to process image: {e}")

    def _show_preview(self, image):
        """Draw *image*, stretched to the canvas size, replacing any old preview."""
        preview = image.resize((self.PREVIEW_SIZE, self.PREVIEW_SIZE))
        # Store the PhotoImage on the instance so it is not garbage-collected
        # while the canvas is still displaying it.
        self.photo_img = ImageTk.PhotoImage(preview)
        # Remove earlier canvas items; otherwise one item is added per load.
        self.image_canvas.delete("all")
        centre = self.PREVIEW_SIZE // 2
        self.image_canvas.create_image(centre, centre, image=self.photo_img)

    def _predict_label(self, image):
        """Return the ``id2label`` entry for the highest-scoring class of *image*."""
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_id = logits.argmax(-1).item()
        return id2label[predicted_id]
if __name__ == "__main__":
    # Build the GUI and hand control to the Tk event loop; this only runs
    # when the file is executed as a script, not when it is imported.
    window = tk.Tk()
    app = ImageClassifierApp(window)
    window.mainloop()