File size: 3,602 Bytes
3e4e43d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dda996b
 
 
 
 
 
 
 
 
3e4e43d
dda996b
 
 
 
 
 
 
 
 
3e4e43d
 
 
2e2833c
3e4e43d
33d7bd6
3e4e43d
dda996b
3e4e43d
 
 
dda996b
3e4e43d
 
 
 
 
 
048fc26
3e4e43d
 
 
048fc26
3e4e43d
048fc26
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
import logging
import os

# Set up logging to capture errors
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Starting application initialization")

# Load class names from class_indices.txt (one class per line; blank lines skipped).
# The line order must match the model's output-index order.
try:
    logger.info("Loading class names from class_indices.txt")
    with open("class_indices.txt", "r", encoding="utf-8") as f:
        class_names = [line.strip() for line in f if line.strip()]
    logger.info("Class names loaded: %s", class_names)
except Exception as e:
    # Fail fast at startup; chain the original error for a full traceback.
    logger.exception("Failed to load class_indices.txt")
    raise RuntimeError(f"Failed to load class_indices.txt: {e}") from e

# Load the Keras .h5 model (compile=False: inference only, no optimizer state needed).
try:
    logger.info("Loading model: garbage_classifier.h5")
    model = tf.keras.models.load_model("garbage_classifier.h5", compile=False)
    logger.info("Model loaded successfully")
except Exception as e:
    logger.exception("Failed to load model")
    raise RuntimeError(f"Failed to load model: {e}") from e

def predict_image(image: Image.Image) -> str:
    """Classify a garbage image and return an HTML report.

    Resizes *image* to the model's expected 128x128 RGB input, runs
    inference, and renders the top class, its confidence, and a
    per-class probability bar chart as HTML.

    Args:
        image: The uploaded picture as a PIL image.

    Returns:
        An HTML fragment with the prediction, or an error message
        string on failure (errors are never raised, so the Gradio UI
        can always display something).
    """
    try:
        logger.info("Processing image for prediction")
        # Preprocess: model expects a (1, 128, 128, 3) float32 batch scaled to [0, 1].
        img = image.convert("RGB").resize((128, 128))
        arr = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)

        # Single-image batch -> take the first (only) prediction vector.
        preds = model.predict(arr)[0]
        logger.info("Prediction completed")

        pred_class_idx = int(np.argmax(preds))
        pred_class = class_names[pred_class_idx]
        confidence = float(preds[pred_class_idx])

        # One flex row per class with a native <progress> bar; join once
        # instead of repeated string concatenation.
        rows = []
        for i, class_name in enumerate(class_names):
            prob = float(preds[i])
            rows.append(
                '<div style="display: flex; align-items: center; margin-bottom: 10px;">'
                f'<span style="width: 100px; font-weight: bold;">{class_name}: {prob:.2f}</span>'
                f'<progress value="{prob}" max="1" style="width: 200px; margin-left: 10px; height: 20px;"></progress>'
                '</div>'
            )
        prob_html = "".join(rows)

        # Return formatted HTML output with titles
        output = f"""
        <div style="font-family: Arial, sans-serif;">
            <h3>Prediction Class: {pred_class}</h3>
            <h3>Confidence: {confidence:.2f}</h3>
            <h3>Per-class probabilities</h3>
            {prob_html}
        </div>
        """
        return output
    except Exception as e:
        # Boundary handler: log the full traceback, surface only the message.
        logger.exception("Error during prediction")
        return f"Error during prediction: {str(e)}"

# Assemble the Gradio UI; the custom CSS enlarges the HTML result panel.
logger.info("Initializing Gradio interface")
_RESULT_CSS = ".large-output { height: 400px !important; overflow: visible !important; }"
with gr.Blocks(css=_RESULT_CSS) as iface:
    gr.Markdown("# ♻️ Garbage Classification Application")
    gr.Markdown("Upload an image of garbage (plastic, organic, or metal):")
    image_input = gr.Image(type="pil")
    result_panel = gr.HTML(label="Prediction Result", elem_classes="large-output")
    classify_button = gr.Button("Classify")
    # Wire the button to the classifier: PIL image in, HTML report out.
    classify_button.click(fn=predict_image, inputs=image_input, outputs=result_panel)

# Enable request queueing before serving.
iface.queue(api_open=False)  # Disable public API for simplicity
logger.info("Gradio interface queued and ready for serving")

# Launch the server only when run as a script (not on import).
if __name__ == "__main__":
    logger.info("Starting Gradio server")
    iface.launch(server_name="0.0.0.0", server_port=7860, share=False)