# Author: Yasiru Chamuditha
# Commit dda996b: "fix issues"
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
import logging
import os
# Module logger for startup progress and prediction errors. basicConfig
# installs a root handler at INFO level (stderr by default) so these
# messages are actually emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logger.info("Starting application initialization")
# Load the class labels from class_indices.txt, one label per non-blank line
# (the path is relative to the current working directory).
# Order matters: predict_image() indexes this list with the model's argmax,
# so line i must name the class for model output i.
try:
    logger.info("Loading class names from class_indices.txt")
    with open("class_indices.txt", "r", encoding="utf-8") as f:
        # Iterate the file lazily and strip each line only once.
        class_names = [name for name in (line.strip() for line in f) if name]
except Exception as e:
    # logger.exception records the traceback, not just the message.
    logger.exception(f"Failed to load class_indices.txt: {e}")
    # Chain the original error so the root cause stays attached.
    raise RuntimeError(f"Failed to load class_indices.txt: {e}") from e
if not class_names:
    # An empty label list would only surface later as an IndexError on
    # every prediction; fail fast at startup instead.
    raise RuntimeError("class_indices.txt contains no class names")
logger.info(f"Class names loaded: {class_names}")
# Load the trained Keras classifier from garbage_classifier.h5 (the path is
# relative to the current working directory). compile=False: the model is
# only used for predict() here, so it is not compiled after loading.
try:
    logger.info("Loading model: garbage_classifier.h5")
    model = tf.keras.models.load_model("garbage_classifier.h5", compile=False)
except Exception as e:
    # Keep the traceback in the log and chain the original cause.
    logger.exception(f"Failed to load model: {e}")
    raise RuntimeError(f"Failed to load model: {e}") from e
logger.info("Model loaded successfully")
def predict_image(image: Image.Image) -> str:
    """Classify an uploaded image and render the result as HTML.

    The image is converted to RGB and resized to 128x128. It is then scaled
    to [0, 1] and passed to the module-level Keras ``model`` as a batch of
    one. The HTML shows the top class, its score, and one progress bar per
    entry in ``class_names``.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.

    Returns:
        An HTML string on success. Otherwise, a short (HTML-escaped) message
        describing the problem. Exceptions are logged and never propagated
        to the UI.
    """
    import html

    # Guard against the Classify button being pressed with no upload.
    if image is None:
        return "Please upload an image before clicking Classify."

    try:
        logger.info("Processing image for prediction")
        # Build a (1, 128, 128, 3) float32 batch with pixel values in [0, 1].
        img = image.convert("RGB").resize((128, 128))
        arr = np.array(img).astype("float32") / 255.0
        arr = np.expand_dims(arr, axis=0)

        # verbose=0 stops Keras from printing a progress bar on every request.
        preds = model.predict(arr, verbose=0)[0]
        logger.info("Prediction completed")

        # Labels are matched to model outputs by position. Refuse to guess
        # when the counts disagree, instead of raising IndexError or
        # silently dropping classes.
        if len(preds) != len(class_names):
            raise ValueError(
                f"Model returned {len(preds)} scores but "
                f"{len(class_names)} class names were loaded"
            )

        pred_class_idx = int(np.argmax(preds))
        pred_class = class_names[pred_class_idx]
        confidence = float(preds[pred_class_idx])

        # One row per class: "name: score" followed by a progress bar.
        # Labels are escaped so that stray "<" or "&" cannot break the markup.
        rows = []
        for class_name, score in zip(class_names, preds):
            prob = float(score)
            rows.append(
                '<div style="display: flex; align-items: center; margin-bottom: 10px;">'
                f'<span style="width: 100px; font-weight: bold;">{html.escape(class_name)}: {prob:.2f}</span>'
                f'<progress value="{prob}" max="1" style="width: 200px; margin-left: 10px; height: 20px;"></progress>'
                "</div>"
            )
        prob_html = "".join(rows)

        return (
            '<div style="font-family: Arial, sans-serif;">'
            f"<h3>Prediction Class: {html.escape(pred_class)}</h3>"
            f"<h3>Confidence: {confidence:.2f}</h3>"
            "<h3>Per-class probabilities</h3>"
            f"{prob_html}"
            "</div>"
        )
    except Exception as e:
        # This is the UI boundary: log the full traceback and show a message
        # instead of crashing the request.
        logger.exception(f"Error during prediction: {e}")
        return html.escape(f"Error during prediction: {e}")
# Web UI: an image upload, a Classify button, and an HTML panel that shows
# whatever predict_image() returns. The CSS gives that panel a fixed
# 400px height.
_RESULT_PANEL_CSS = ".large-output { height: 400px !important; overflow: visible !important; }"

logger.info("Initializing Gradio interface")
with gr.Blocks(css=_RESULT_PANEL_CSS) as iface:
    gr.Markdown("# ♻️ Garbage Classification Application")
    gr.Markdown("Upload an image of garbage (plastic, organic, or metal):")
    # type="pil" makes Gradio pass a PIL image to predict_image.
    img_input = gr.Image(type="pil")
    # elem_classes attaches the .large-output rule defined above.
    output = gr.HTML(label="Prediction Result", elem_classes="large-output")
    classify_button = gr.Button("Classify")
    classify_button.click(fn=predict_image, inputs=img_input, outputs=output)

# Enable request queuing. According to the Gradio docs, api_open=False keeps
# the backend REST routes closed, so direct API requests cannot skip the queue.
iface.queue(api_open=False)
logger.info("Gradio interface queued and ready for serving")
# Start the Gradio server when run as a script. By default it listens on all
# interfaces on port 7860. Gradio's standard GRADIO_SERVER_NAME and
# GRADIO_SERVER_PORT environment variables override these defaults.
if __name__ == "__main__":
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    logger.info("Starting Gradio server")
    iface.launch(server_name=server_name, server_port=server_port, share=False)