chemoiko's picture
Update app.py
72a6a28 verified
import gradio as gr
import tensorflow as tf
import numpy as np
import re
import base64
import io
from PIL import Image
# Artifact locations, relative to the working directory the app is started from.
MODEL_PATH = 'banana_disease_densenet121.keras'
CLASS_NAMES_PATH = "class_names.npy"

# The trained Keras classifier, loaded once when this module runs and then
# shared by every call to predict_disease.
model = tf.keras.models.load_model(MODEL_PATH)

# Human-readable labels, indexed by the model's output position (assumed to
# match the training label order — confirm against the training script).
# NOTE: allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code; only ever point this at a trusted, bundled file.
class_names = np.load(CLASS_NAMES_PATH, allow_pickle=True)
# Preprocess the input image
def preprocess_image(img):
    """Turn an image into a single-item batch the model can consume.

    Args:
        img: An image object exposing ``resize`` (e.g. a PIL image), or None.

    Returns:
        None when ``img`` is None. Otherwise a NumPy array holding the image
        resized to 256x256, with pixel values divided by 255, and with a
        leading batch axis of length 1 added.
    """
    if img is not None:
        # 256x256 is the input size the model was trained on.
        resized = img.resize((256, 256))
        scaled = np.asarray(resized) / 255.0
        # Prepend the batch dimension expected by model.predict.
        return scaled[np.newaxis, ...]
    return None
# Decode a base64 image payload (as sent by the Flutter client).
def _decode_base64_image(payload):
    """Decode a base64 string, optionally a ``data:image/...;base64,`` URL, to a PIL image.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not a readable image.
    """
    # Remove data URL prefix if present (e.g., "data:image/jpeg;base64,")
    if payload.startswith('data:image'):
        payload = re.sub(r'^data:image/.+;base64,', '', payload)
    image_data = base64.b64decode(payload)
    return Image.open(io.BytesIO(image_data))


# Prediction function
def predict_disease(img):
    """
    Predict banana disease from an image.

    Accepts either a PIL image (from the Gradio UI) or a base64 string,
    optionally prefixed with a ``data:image/...;base64,`` header (from Flutter).

    Returns:
        A ``(message, scores)`` tuple.
        ``message`` names the top-scoring class across *all* model outputs,
        or explains why no prediction was made.
        ``scores`` maps class name -> probability, restricted to
        'Banana Panama Disease' and 'Banana Healthy Leaf'. It is ``{}`` when
        no image was supplied or an error occurred.
    """
    no_image = ("⚠️ No image provided", {})
    try:
        # Handle base64 encoded images from Flutter
        if isinstance(img, str):
            img = img.strip()
            # An empty payload means the client sent no image at all.
            if not img:
                return no_image
            img = _decode_base64_image(img)

        # Check for a missing image *before* touching img.mode; previously a
        # None input raised AttributeError and surfaced as a cryptic error.
        if img is None:
            return no_image

        # Ensure RGB mode (drops alpha / expands grayscale to 3 channels)
        if img.mode != 'RGB':
            img = img.convert('RGB')

        img_processed = preprocess_image(img)

        # verbose=0 stops Keras from printing a progress bar on every request.
        predictions = model.predict(img_processed, verbose=0)[0]
        predicted_class = int(np.argmax(predictions))

        # Only these two classes are exposed in the confidence-score output.
        filtered_classes = ('Banana Panama Disease', 'Banana Healthy Leaf')
        confidence_scores = {
            class_name: float(predictions[i])
            for i, class_name in enumerate(class_names)
            if class_name in filtered_classes
        }
        return f"Predicted: {class_names[predicted_class]}", confidence_scores
    except Exception as e:
        # UI boundary: report the failure to the caller instead of crashing.
        print(f"Error in prediction: {str(e)}")
        return f"Error: {str(e)}", {}
# UI components: one image upload in; a text verdict plus a two-entry
# confidence label out (num_top_classes=2 matches the two exposed classes).
leaf_input = gr.Image(type="pil", label="Upload Banana Leaf")
prediction_output = gr.Text(label="Prediction")
scores_output = gr.Label(label="Confidence Scores", num_top_classes=2)

# Wire predict_disease into a Gradio Interface, also exposed over the API
# under the endpoint name "predict".
demo = gr.Interface(
    fn=predict_disease,
    inputs=leaf_input,
    outputs=[prediction_output, scores_output],
    title="🍌 Banana Leaf Disease Classifier",
    description="Upload a banana leaf image, and our AI will diagnose the disease",
    theme=gr.themes.Soft(),
    api_name="predict",
)

# Launch the app as soon as this module runs (there is no __main__ guard).
demo.launch()