"""Tkinter GUI for classifying images with a fine-tuned SigLIP2 model.

The image processor and model are loaded once, at import time, from
``model_path``. The GUI lets the user pick an image file, shows a preview of
it, and displays the label predicted by the model.
"""

import tkinter as tk
from tkinter import filedialog, messagebox

from PIL import Image, ImageTk
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification

# Model and processor paths (adjust if needed; assumes final model saved in output_dir)
model_path = "./siglip2_finetuned"  # Or "./siglip2_finetuned/checkpoint-1284" if using a specific checkpoint

# Load processor and model once so every prediction reuses them.
processor = AutoImageProcessor.from_pretrained(model_path)
model = SiglipForImageClassification.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # evaluation mode for inference

# Get label mappings from model config
id2label = model.config.id2label

# Side length, in pixels, of the square preview canvas.
CANVAS_SIZE = 400


class ImageClassifierApp:
    """Main window: a load button, an image preview, and the predicted grade."""

    def __init__(self, root):
        """Build the widgets inside *root* (a ``tk.Tk`` window)."""
        self.root = root
        self.root.title("SigLIP2 Gardner Grading Classifier")
        self.root.geometry("600x600")

        # Label for instructions
        self.instruction_label = tk.Label(root, text="Select an image to classify")
        self.instruction_label.pack(pady=10)

        # Button to load image
        self.load_button = tk.Button(root, text="Load Image", command=self.load_image)
        self.load_button.pack(pady=10)

        # Canvas to display image
        self.image_canvas = tk.Canvas(
            root, width=CANVAS_SIZE, height=CANVAS_SIZE, bg="white"
        )
        self.image_canvas.pack(pady=10)

        # Label to display prediction
        self.prediction_label = tk.Label(root, text="", font=("Arial", 14))
        self.prediction_label.pack(pady=10)

    def load_image(self):
        """Ask for an image file, preview it, and show the model's prediction.

        Does nothing if the dialog is cancelled. Any error while reading or
        classifying the image is reported in a message box instead of being
        raised, so the GUI stays usable.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.png *.jpg *.jpeg"), ("All files", "*")]
        )
        if not file_path:
            return  # dialog cancelled

        # Clear the previous result so a failure below never leaves a grade
        # that belongs to a different image on screen.
        self.prediction_label.config(text="")
        try:
            img = self._open_rgb(file_path)
            self._show_preview(img)
            predicted_label = self._predict(img)
        except Exception as e:  # GUI boundary: report rather than kill the event loop
            messagebox.showerror("Error", f"Failed to process image: {str(e)}")
            return

        # Display prediction
        self.prediction_label.config(text=f"Predicted Grade: {predicted_label}")

    @staticmethod
    def _open_rgb(file_path):
        """Read *file_path* fully into memory as an RGB image and close the file."""
        # convert() loads the pixel data and returns a new image, so it is safe
        # to close the file when the with-block exits. Converting to RGB turns
        # RGBA / grayscale / palette images into 3-channel input for the processor.
        with Image.open(file_path) as src:
            return src.convert("RGB")

    def _show_preview(self, img):
        """Draw *img* centred on the canvas, scaled to fit and keeping its aspect ratio."""
        width, height = img.size
        scale = min(CANVAS_SIZE / width, CANVAS_SIZE / height)
        preview = img.resize(
            (max(1, round(width * scale)), max(1, round(height * scale)))
        )
        # Keep a reference on self; otherwise the PhotoImage is garbage-collected
        # and the canvas shows nothing.
        self.photo_img = ImageTk.PhotoImage(preview)
        self.image_canvas.delete("all")  # remove the previous preview item
        self.image_canvas.create_image(
            CANVAS_SIZE // 2, CANVAS_SIZE // 2, image=self.photo_img
        )

    @staticmethod
    def _predict(img):
        """Return the label the module-level model predicts for *img*."""
        inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_id = logits.argmax(-1).item()
        return id2label[predicted_id]


def main():
    """Create the Tk root window and run the application's event loop."""
    root = tk.Tk()
    app = ImageClassifierApp(root)  # keep a reference while the main loop runs
    root.mainloop()


if __name__ == "__main__":
    main()