"""Gradio demo that classifies bed-installation photos with EfficientNet-B0."""

import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm  # EfficientNet library
import gradio as gr
from PIL import Image

# Inference device: fixed to the CPU (the author notes no GPU is available).
device = torch.device("cpu")

# Build an EfficientNet-B0 with a two-class head and restore the trained
# weights from the local checkpoint file.
# NOTE(review): torch.load unpickles the checkpoint; if the file's provenance
# is ever in doubt, consider passing weights_only=True.
model = timm.create_model(
    "efficientnet_b0",
    pretrained=False,  # weights come from the checkpoint below instead
    num_classes=2,  # binary classification
)
model.load_state_dict(
    torch.load("efficientnet_b0_model.pth", map_location=device)
)
model.to(device)
model.eval()  # switch layers such as dropout / batch-norm to inference mode

# Preprocessing applied to every image before inference. The original author
# states this matches the training pipeline (not verifiable from this file).
transform = transforms.Compose(
    [
        # Resize to a fixed 224x224 input.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel standardisation with three-channel statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Paths to the example images shown in the UI.
# Only the numbered files below are enabled. Entries the author had listed but
# commented out: correct (1).webp, (3).jpeg, (4).jpeg, (6).jpeg and
# incorrect (10), (12), (15), (16).webp.
correct_examples = [
    f"dataset/correct/correct ({number}).webp" for number in (2, 5)
]

incorrect_examples = [
    f"dataset/incorrect/incorrect ({number}).webp" for number in (1, 3)
]

# Function to classify an image
def classify_installation(image):
    """Classify whether the bed installation in ``image`` is correct.

    Args:
        image: A PIL image, or ``None``. Gradio passes ``None`` to the
            ``change`` handler when the image component is cleared.

    Returns:
        ``""`` when ``image`` is ``None`` (so the result box is cleared
        instead of the handler raising); otherwise
        ``"\u2705 Correct Installation"`` when the model's top-scoring class
        index is 0, and ``"\u274c Incorrect Installation"`` for any other index.
    """
    if image is None:
        return ""

    # Normalize is configured with three-channel statistics, so grayscale,
    # palette or RGBA PIL input must be converted before the transform runs.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Apply the preprocessing pipeline and add the batch dimension.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)

    # Index of the highest-scoring class for the single image in the batch.
    predicted = int(outputs.argmax(dim=1).item())

    # Escapes keep the source ASCII-safe: U+2705 is the check mark and U+274C
    # the cross mark.
    if predicted == 0:
        return "\u2705 Correct Installation"
    return "\u274c Incorrect Installation"

# Function to load image from file path
def load_image(image_path):
    """Open an image file and return it as an RGB PIL image.

    Args:
        image_path: Path of the image file, passed to ``Image.open``.

    Returns:
        A new PIL image in ``"RGB"`` mode.
    """
    # convert() decodes the pixels into a separate image object, so the file
    # opened here can be closed deterministically when the block exits rather
    # than whenever the lazily-loaded source image is garbage-collected.
    with Image.open(image_path) as source:
        return source.convert("RGB")

# Function to process a selected example image
def process_example(image_path):
    """Load the example at ``image_path`` and classify it.

    Returns:
        A ``(result_text, image)`` tuple: the classification string followed
        by the loaded RGB image.
    """
    example = load_image(image_path)
    verdict = classify_installation(example)
    return verdict, example

# Gradio UI
def _build_example_row(example_paths):
    """Render one row of example thumbnails, each with its own button.

    Must be called while a ``gr.Blocks`` context is active. Returns a list of
    ``(button, image_path)`` pairs so the caller can attach click handlers.
    """
    pairs = []
    with gr.Row():
        for path in example_paths:
            with gr.Column():
                gr.Image(value=load_image(path), interactive=False, width=150, height=150)
                pairs.append((gr.Button(value="Check Accuracy"), path))
    return pairs


with gr.Blocks() as demo:
    gr.Markdown("# Installation Classifier")
    gr.Markdown("Upload an image or select one from the examples below to check if the bed installation is correct.")

    # Upload widget and result box side by side.
    with gr.Row():
        uploaded_image = gr.Image(type="pil", label="Upload an image for testing")
        output_text = gr.Textbox(label="Result")

    gr.Markdown("### Check Installations Examples (Click Button to Classify)")
    correct_buttons = _build_example_row(correct_examples)
    incorrect_buttons = _build_example_row(incorrect_examples)

    # Connect each example button to the classifier. The path is wrapped in a
    # gr.State so it is supplied as the handler's input; the handler fills in
    # both the result text and the image component.
    for btn, img_path in [*correct_buttons, *incorrect_buttons]:
        btn.click(
            fn=process_example,
            inputs=[gr.State(img_path)],
            outputs=[output_text, uploaded_image],
        )

    # Classify whenever the image component's value changes.
    uploaded_image.change(
        fn=classify_installation,
        inputs=[uploaded_image],
        outputs=output_text,
    )

# Launch Gradio UI
demo.launch()