Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import re | |
| import base64 | |
| import io | |
| from PIL import Image | |
# Relative paths of the serialized artifacts this app loads at import time.
_MODEL_FILE = 'banana_disease_densenet121.keras'
_CLASS_NAMES_FILE = "class_names.npy"

# Saved Keras model, restored once so every request reuses the same instance.
model = tf.keras.models.load_model(_MODEL_FILE)

# Class labels, indexed in the same order as the model's output scores.
# NOTE: allow_pickle=True lets NumPy unpickle object arrays, so this file
# must come from a trusted source.
class_names = np.load(_CLASS_NAMES_FILE, allow_pickle=True)
| # Preprocess the input image | |
def preprocess_image(img, target_size=(256, 256)):
    """Convert an image into a normalized, batched array for the model.

    Args:
        img: A PIL-style image exposing ``mode``, ``convert`` and
            ``resize``, or ``None`` when no image was supplied.
        target_size: ``(width, height)`` passed to ``img.resize``. The
            default of 256x256 is the size the original code notes as
            matching training.

    Returns:
        ``None`` if ``img`` is ``None``; otherwise an array with a leading
        batch axis of length 1 and pixel values divided by 255.
    """
    if img is None:
        return None

    # Force three-channel RGB so grayscale, palette or alpha images produce
    # the same array layout as RGB ones.
    if img.mode != "RGB":
        img = img.convert("RGB")

    resized = img.resize(target_size)
    scaled = np.array(resized) / 255.0  # Scale 0-255 pixel values by 1/255
    return np.expand_dims(scaled, axis=0)  # Add batch dimension
| # Prediction function | |
def predict_disease(img):
    """Predict the banana leaf disease shown in an image.

    Args:
        img: A PIL image, or a base64-encoded image string (e.g. sent by
            the Flutter client), optionally prefixed with a
            ``data:image/...;base64,`` data-URL header. ``None`` or a
            blank string means no image was supplied.

    Returns:
        A ``(message, confidence_scores)`` tuple. ``message`` is
        ``"Predicted: <class>"`` for the highest-scoring class across all
        classes, ``"⚠️ No image provided"`` when there is no input, or
        ``"Error: <details>"`` if anything fails. ``confidence_scores``
        maps only 'Banana Panama Disease' and 'Banana Healthy Leaf' to
        their scores, so the predicted class may be absent from it; it is
        empty when no prediction was made.
    """
    # Reject missing input up front so the caller gets the "no image"
    # message rather than an attribute error from the checks below.
    if img is None or (isinstance(img, str) and not img.strip()):
        return "⚠️ No image provided", {}

    try:
        # Handle base64 encoded images from Flutter
        if isinstance(img, str):
            # Remove data URL prefix if present (e.g., "data:image/jpeg;base64,")
            if img.startswith('data:image'):
                img = re.sub(r'^data:image/.+;base64,', '', img)
            # Decode base64 to image
            image_data = base64.b64decode(img)
            img = Image.open(io.BytesIO(image_data))

        # Ensure RGB mode
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Preprocess the image
        img_processed = preprocess_image(img)
        if img_processed is None:
            return "⚠️ No image provided", {}

        # Predict; [0] drops the batch axis added during preprocessing
        predictions = model.predict(img_processed)[0]
        predicted_class = np.argmax(predictions)

        # Filter to show only Panama Disease and Healthy Leaf
        filtered_classes = ['Banana Panama Disease', 'Banana Healthy Leaf']
        confidence_scores = {
            class_name: float(predictions[i])
            for i, class_name in enumerate(class_names)
            if class_name in filtered_classes
        }
        return f"Predicted: {class_names[predicted_class]}", confidence_scores
    except Exception as e:
        # Top-level request boundary: report the failure to the caller
        # instead of letting the exception propagate.
        print(f"Error in prediction: {str(e)}")
        return f"Error: {str(e)}", {}
# Input widget: hands the uploaded picture to the callback as a PIL image.
_leaf_input = gr.Image(type="pil", label="Upload Banana Leaf")

# Output widgets, in the same order as predict_disease's return tuple:
# the text message first, then the per-class confidence scores.
_result_outputs = [
    gr.Text(label="Prediction"),
    gr.Label(label="Confidence Scores", num_top_classes=2),
]

# Wire the prediction callback to the UI and expose it under the
# "predict" API name.
demo = gr.Interface(
    fn=predict_disease,
    inputs=_leaf_input,
    outputs=_result_outputs,
    title="🍌 Banana Leaf Disease Classifier",
    description="Upload a banana leaf image, and our AI will diagnose the disease",
    theme=gr.themes.Soft(),
    api_name="predict",
)

# Start the app as soon as the script runs.
demo.launch()