# Garbage-classification demo (Hugging Face Space): a Gradio front end
# around a Keras image classifier (plastic / organic / metal).
import html
import logging
import os

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# Configure application-wide logging so startup and prediction problems
# show up in the Space's container logs.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logger.info("Starting application initialization")
# Load the ordered class labels; line i maps to model output index i.
try:
    logger.info("Loading class names from class_indices.txt")
    with open("class_indices.txt", "r", encoding="utf-8") as f:
        # One label per line; skip blanks so trailing newlines don't add classes.
        class_names = [line.strip() for line in f if line.strip()]
    logger.info("Class names loaded: %s", class_names)
except Exception as e:
    logger.error("Failed to load class_indices.txt: %s", e)
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Failed to load class_indices.txt: {e}") from e
# Load the trained Keras model; compile=False because we only run inference,
# so the optimizer/loss configuration is not needed.
try:
    logger.info("Loading model: garbage_classifier.h5")
    model = tf.keras.models.load_model("garbage_classifier.h5", compile=False)
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error("Failed to load model: %s", e)
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e
def predict_image(image: Image.Image) -> str:
    """Classify an uploaded image and return an HTML result panel.

    Args:
        image: Uploaded picture (any mode/size); it is converted to RGB
            and resized to the 128x128 input the model expects.

    Returns:
        An HTML fragment showing the top class, its confidence, and a
        progress bar per class — or a plain error string on failure
        (Gradio renders either in the HTML output component).
    """
    try:
        logger.info("Processing image for prediction")
        # Preprocess: RGB, 128x128, scaled to [0, 1], with a batch axis.
        img = image.convert("RGB").resize((128, 128))
        arr = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)

        preds = model.predict(arr)[0]
        logger.info("Prediction completed")

        # Top prediction; class_names order matches the model's output units.
        pred_class_idx = int(np.argmax(preds))
        pred_class = class_names[pred_class_idx]
        confidence = float(preds[pred_class_idx])

        # Per-class probability rows. Labels are escaped before being
        # interpolated into HTML in case the label file contains markup.
        rows = []
        for class_name, prob in zip(class_names, preds):
            prob = float(prob)
            rows.append(
                '<div style="display: flex; align-items: center; margin-bottom: 10px;">'
                f'<span style="width: 100px; font-weight: bold;">{html.escape(class_name)}: {prob:.2f}</span>'
                f'<progress value="{prob}" max="1" style="width: 200px; margin-left: 10px; height: 20px;"></progress>'
                '</div>'
            )
        prob_html = "".join(rows)

        return f"""
<div style="font-family: Arial, sans-serif;">
<h3>Prediction Class: {html.escape(pred_class)}</h3>
<h3>Confidence: {confidence:.2f}</h3>
<h3>Per-class probabilities</h3>
{prob_html}
</div>
"""
    except Exception as e:
        # Broad catch is deliberate: surface the error in the UI instead of
        # crashing the Gradio worker; logger.exception keeps the traceback.
        logger.exception("Error during prediction")
        return f"Error during prediction: {e}"
# Assemble the Gradio UI. The custom CSS pins the HTML output panel to a
# fixed height so long probability lists stay scrollable without collapsing.
logger.info("Initializing Gradio interface")
with gr.Blocks(css=".large-output { height: 400px !important; overflow: visible !important; }") as iface:
    gr.Markdown("# ♻️ Garbage Classification Application")
    gr.Markdown("Upload an image of garbage (plastic, organic, or metal):")
    image_input = gr.Image(type="pil")
    result_panel = gr.HTML(label="Prediction Result", elem_classes="large-output")
    classify_button = gr.Button("Classify")
    # Wire the button: the uploaded PIL image goes in, HTML comes back.
    classify_button.click(fn=predict_image, inputs=image_input, outputs=result_panel)
# Queue the interface for serving
# (requests are processed through Gradio's queue; api_open=False keeps the
# auto-generated REST API endpoints private).
iface.queue(api_open=False)  # Disable public API for simplicity
logger.info("Gradio interface queued and ready for serving")
# Start Gradio server explicitly
if __name__ == "__main__":
    logger.info("Starting Gradio server")
    # Bind to all interfaces on port 7860 — the port Hugging Face Spaces expects.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=False)