# Tngarg's picture
# Update app.py
# a95549a verified
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm # EfficientNet library
import gradio as gr
from PIL import Image
# 🔹 Inference runs on the CPU (no GPU is available for this app).
device = torch.device("cpu")

# 🔹 EfficientNet-B0 with a 2-class head (binary: correct vs. incorrect installation).
# No pretrained weights are downloaded; all parameters come from the trained
# checkpoint loaded below (load_state_dict is strict by default).
model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
_checkpoint = torch.load("efficientnet_b0_model.pth", map_location=device)
model.load_state_dict(_checkpoint)
del _checkpoint  # the weights have been copied into `model`; drop the extra reference
model.to(device)
model.eval()  # inference behaviour for layers such as dropout and batch norm
# 🔹 Preprocessing applied to every image before inference. It must mirror the
# pipeline used during training: fixed 224x224 resize, conversion to a float
# tensor, then per-channel normalisation with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# 🔹 Example images shown as clickable thumbnails, one list per class.
# Alternates currently commented out (same folders):
#   "dataset/correct/correct (1).webp", "dataset/correct/correct (3).jpeg",
#   "dataset/correct/correct (4).jpeg", "dataset/correct/correct (6).jpeg"
#   "dataset/incorrect/incorrect (10).webp", "dataset/incorrect/incorrect (12).webp",
#   "dataset/incorrect/incorrect (15).webp", "dataset/incorrect/incorrect (16).webp"
correct_examples = [
    "dataset/correct/correct (2).webp",
    "dataset/correct/correct (5).webp",
]
incorrect_examples = [
    "dataset/incorrect/incorrect (1).webp",
    "dataset/incorrect/incorrect (3).webp",
]
# 🔹 Function to classify an image
def classify_installation(image):
    """Classify if the bed installation is correct or incorrect using EfficientNet-B0.

    Args:
        image: A PIL image to classify, or ``None`` (e.g. when the upload box
            is cleared and the UI's change handler fires with no value).

    Returns:
        A human-readable verdict string for the Result textbox, or an empty
        string when there is no image to classify.
    """
    if image is None:
        # Nothing to classify: clear the result instead of crashing in the transform.
        return ""
    if image.mode != "RGB":
        # ToTensor emits one channel per band; RGBA / grayscale / palette images
        # would not match the 3-channel normalisation and model input.
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # add a leading batch dimension
    with torch.no_grad():
        logits = model(batch)
    predicted = int(logits.argmax(dim=1).item())  # index of the highest-scoring class
    # Class index 0 means "correct"; any other index means "incorrect".
    return "✅ Correct Installation" if predicted == 0 else "❌ Incorrect Installation"
# 🔹 Function to load image from file path
def load_image(image_path):
    """Open the image at *image_path* and return it as an RGB PIL image.

    ``Image.open`` is lazy and can keep the underlying file open (e.g. for
    multi-frame formats), so it is used as a context manager: ``convert``
    returns a new, fully loaded image, after which the file handle is closed.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")
# 🔹 Function to process selected example image
def process_example(image_path):
    """Load one of the bundled example images and classify it.

    Args:
        image_path: Path of the example image file on disk.

    Returns:
        A ``(verdict_text, image)`` pair: the classification result and the
        loaded RGB image, so the UI can show both.
    """
    picture = load_image(image_path)
    verdict = classify_installation(picture)
    return verdict, picture
# 🔹 Gradio UI
def _add_example_row(image_paths):
    """Render one row of example thumbnails, each with a "Check Accuracy" button.

    Must be called while a ``gr.Blocks`` context is active. Returns a list of
    ``(button, image_path)`` pairs so the caller can wire up click handlers.
    """
    buttons = []
    with gr.Row():
        for img_path in image_paths:
            with gr.Column():
                gr.Image(value=load_image(img_path), interactive=False, width=150, height=150)
                buttons.append((gr.Button(value="Check Accuracy"), img_path))
    return buttons


with gr.Blocks() as demo:
    gr.Markdown("# Installation Classifier")
    gr.Markdown("Upload an image or select one from the examples below to check if the bed installation is correct.")
    with gr.Row():
        uploaded_image = gr.Image(type="pil", label="Upload an image for testing")
        output_text = gr.Textbox(label="Result")

    gr.Markdown("### Check Installations Examples (Click Button to Classify)")
    # One thumbnail row per class, built by the same helper so both rows stay identical.
    correct_buttons = _add_example_row(correct_examples)
    incorrect_buttons = _add_example_row(incorrect_examples)

    # 🔹 Connect buttons to classification function.
    # A fresh gr.State holds each button's path, so every handler gets its own
    # image (no late binding to the loop variable).
    for btn, img_path in correct_buttons + incorrect_buttons:
        btn.click(fn=process_example, inputs=[gr.State(img_path)], outputs=[output_text, uploaded_image])

    # 🔹 Process uploaded image
    # NOTE(review): in Gradio, `change` also fires when an event handler writes a
    # value into this component, so an example click classifies the image twice.
    # The result is identical; consider `.input`/`.upload` if the pinned Gradio
    # version supports them for gr.Image — confirm before switching.
    uploaded_image.change(fn=classify_installation, inputs=[uploaded_image], outputs=output_text)

# 🔹 Launch Gradio UI
demo.launch()