# Spaces: Sleeping
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm  # EfficientNet implementations
import gradio as gr
from PIL import Image

# Run inference on the CPU (no GPU is available in this deployment).
device = torch.device("cpu")

# EfficientNet-B0 with a 2-way output head (binary classification). The
# architecture is built without pretrained weights because the trained
# weights are loaded from the local checkpoint right below.
model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict needs, so a tampered checkpoint file cannot execute
# arbitrary code when it is loaded.
state_dict = torch.load("efficientnet_b0_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for dropout / batch-norm layers
model.to(device)
# Preprocessing applied to every image before it reaches the model; it has to
# stay identical to the transforms used during training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel mean / std (the standard values used for this normalization).
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
# Example images displayed below the upload box: one row per class.
# Disabled alternatives that were previously listed here:
#   dataset/correct/correct (1).webp, correct (3).jpeg, correct (4).jpeg, correct (6).jpeg
#   dataset/incorrect/incorrect (10).webp, (12).webp, (15).webp, (16).webp
correct_examples = list(
    (
        "dataset/correct/correct (2).webp",
        "dataset/correct/correct (5).webp",
    )
)
incorrect_examples = list(
    (
        "dataset/incorrect/incorrect (1).webp",
        "dataset/incorrect/incorrect (3).webp",
    )
)
def classify_installation(image):
    """Classify whether the bed installation in ``image`` is correct or incorrect.

    Uses the module-level EfficientNet-B0 ``model`` with the ``transform``
    preprocessing pipeline.

    Args:
        image: A PIL image, or ``None`` (Gradio passes ``None`` to the
            ``change`` listener when the upload box is cleared).

    Returns:
        ``"✅ Correct Installation"`` when the model predicts class 0,
        ``"❌ Incorrect Installation"`` for any other class, and ``""`` when
        ``image`` is ``None``.
    """
    # Clearing the image triggers the change event with None; report nothing
    # instead of letting the transform raise on a missing image.
    if image is None:
        return ""
    # Normalize expects exactly 3 channels, so coerce RGBA / greyscale input.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted = logits.argmax(dim=1).item()
    # Class index 0 is the "correct installation" class.
    return "✅ Correct Installation" if predicted == 0 else "❌ Incorrect Installation"
def load_image(image_path):
    """Read the image file at ``image_path`` and return it as an RGB PIL image.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the pixels are decoded. ``convert`` returns a new, fully loaded image, so
    the result stays valid after the file is closed.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")
def process_example(image_path):
    """Load the example image stored at ``image_path`` and classify it.

    Returns:
        A ``(classification_text, image)`` tuple: the label string produced by
        ``classify_installation`` followed by the loaded RGB image.
    """
    picture = load_image(image_path)
    verdict = classify_installation(picture)
    return verdict, picture
# Gradio UI: an upload box with its result, plus one row of clickable
# examples per class.
def _add_example_row(image_paths, target_image):
    """Render a row of example thumbnails, each with a "Check Accuracy" button.

    Each button loads its example into ``target_image``. Classification is
    left to ``target_image``'s ``change`` listener, which also fires on
    programmatic updates. Classifying in the click handler as well would run
    the model twice per click.
    """
    with gr.Row():
        for img_path in image_paths:
            with gr.Column():
                gr.Image(value=load_image(img_path), interactive=False, width=150, height=150)
                btn = gr.Button(value="Check Accuracy")
                # gr.State carries this iteration's path into the handler.
                btn.click(fn=load_image, inputs=[gr.State(img_path)], outputs=target_image)


with gr.Blocks() as demo:
    gr.Markdown("# Installation Classifier")
    gr.Markdown("Upload an image or select one from the examples below to check if the bed installation is correct.")
    with gr.Row():
        uploaded_image = gr.Image(type="pil", label="Upload an image for testing")
        output_text = gr.Textbox(label="Result")
    gr.Markdown("### Check Installations Examples (Click Button to Classify)")
    _add_example_row(correct_examples, uploaded_image)
    _add_example_row(incorrect_examples, uploaded_image)
    # Single classification path: user uploads/clears and example-button
    # updates all trigger this listener.
    uploaded_image.change(fn=classify_installation, inputs=[uploaded_image], outputs=output_text)

# Launch the Gradio UI.
demo.launch()