import os

import gradio as gr
import timm  # EfficientNet library
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# 🔹 Set device (CPU since no GPU)
device = torch.device("cpu")

# 🔹 Load EfficientNet-B0 Model
model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)  # Binary Classification
# NOTE: torch.load unpickles the checkpoint file, so only point it at a trusted file.
model.load_state_dict(torch.load("efficientnet_b0_model.pth", map_location=device))  # Load trained weights
model.eval()  # Inference mode: disables dropout / batch-norm statistic updates
model.to(device)

# 🔹 Define Data Transforms (Same as Training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # Standard normalization
])

# 🔹 Paths to example images
correct_examples = [
    # "dataset/correct/correct (1).webp",
    "dataset/correct/correct (2).webp",
    # "dataset/correct/correct (3).jpeg",
    # "dataset/correct/correct (4).jpeg",
    "dataset/correct/correct (5).webp",
    # "dataset/correct/correct (6).jpeg"
]

incorrect_examples = [
    "dataset/incorrect/incorrect (1).webp",
    "dataset/incorrect/incorrect (3).webp",
    # "dataset/incorrect/incorrect (10).webp",
    # "dataset/incorrect/incorrect (12).webp",
    # "dataset/incorrect/incorrect (15).webp",
    # "dataset/incorrect/incorrect (16).webp"
]


# 🔹 Function to classify an image
def classify_installation(image):
    """Classify if the bed installation is correct or incorrect using EfficientNet-B0.

    Args:
        image: A PIL image, or ``None``. Gradio hands ``None`` to a ``change``
            handler when the image widget is cleared.

    Returns:
        "✅ Correct Installation" when the highest-scoring model output is
        class 0, "❌ Incorrect Installation" otherwise, or an empty string
        when ``image`` is ``None`` (so the result box is simply blanked).
    """
    if image is None:
        return ""

    # Normalize() above is configured for exactly three channels; RGBA,
    # palette or grayscale inputs would otherwise fail inside the transform.
    if image.mode != "RGB":
        image = image.convert("RGB")

    batch = transform(image).unsqueeze(0).to(device)  # Apply transforms and add batch dimension
    with torch.no_grad():
        outputs = model(batch)
    _, predicted = torch.max(outputs, 1)  # Index of the highest-scoring class

    # Assumes class index 0 is "correct" in the trained weights — TODO confirm
    # against the label order used during training.
    return "✅ Correct Installation" if predicted.item() == 0 else "❌ Incorrect Installation"


# 🔹 Function to load image from file path
def load_image(image_path):
    """Open the image at ``image_path`` and return it as an RGB PIL image.

    The file handle is closed before returning; ``convert`` yields an
    independent in-memory image, so the result stays usable afterwards.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")


# 🔹 Function to process selected example image
def process_example(image_path):
    """Load an example image and classify it.

    Returns:
        A ``(result_text, image)`` tuple: the classification message and the
        loaded RGB image.
    """
    image = load_image(image_path)
    return classify_installation(image), image
# 🔹 Gradio UI
def _example_row(image_paths):
    """Render one row of example thumbnails, each with its own button.

    Must be called inside an active ``gr.Blocks`` context. Returns a list of
    ``(button, image_path)`` pairs in the same order as ``image_paths``.
    """
    pairs = []
    with gr.Row():
        for path in image_paths:
            with gr.Column():
                gr.Image(value=load_image(path), interactive=False, width=150, height=150)
                button = gr.Button(value="Check Accuracy")
                pairs.append((button, path))
    return pairs


with gr.Blocks() as demo:
    gr.Markdown("# Installation Classifier")
    gr.Markdown("Upload an image or select one from the examples below to check if the bed installation is correct.")

    # Upload widget on the left, classification result on the right.
    with gr.Row():
        uploaded_image = gr.Image(type="pil", label="Upload an image for testing")
        output_text = gr.Textbox(label="Result")

    gr.Markdown("### Check Installations Examples (Click Button to Classify)")

    # One row per example category.
    correct_buttons = _example_row(correct_examples)
    incorrect_buttons = _example_row(incorrect_examples)

    # 🔹 Connect buttons to classification function
    # Each button carries its own file path in a gr.State; clicking it fills
    # both the result box and the upload widget with that example.
    for button, path in [*correct_buttons, *incorrect_buttons]:
        button.click(
            fn=process_example,
            inputs=[gr.State(path)],
            outputs=[output_text, uploaded_image],
        )

    # 🔹 Process uploaded image
    # NOTE(review): an example click also writes to uploaded_image, which
    # presumably fires this change handler too and classifies the same image
    # a second time — same result, just redundant work.
    uploaded_image.change(
        fn=classify_installation,
        inputs=[uploaded_image],
        outputs=output_text,
    )

# 🔹 Launch Gradio UI
demo.launch()