# app.py — Crop disease detection demo (Gradio)
# Provenance: Hugging Face Space by Gyimah3, commit 51dc7db (verified)
import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras import models
from tensorflow.keras.preprocessing.image import img_to_array
import matplotlib.pyplot as plt
import io
import base64
# Cache of loaded Keras models, keyed by file path, so each model is read
# from disk only once instead of on every inference call.
_MODEL_CACHE = {}


def _load_model_cached(model_choice):
    """Return the Keras model for *model_choice*, loading it at most once."""
    if model_choice == "Cassava Model 🍃":
        path = "cassava_model.keras"
    else:
        # NOTE(review): this branch is unreachable from the current UI (the
        # dropdown only offers the cassava model) and the path is Colab-specific
        # ("/content/..."); confirm the intended maize model location.
        path = "/content/maize_model.keras"
    if path not in _MODEL_CACHE:
        _MODEL_CACHE[path] = models.load_model(path)
    return _MODEL_CACHE[path]


def inference(image, model_choice):
    """Classify a leaf image as healthy vs. diseased and build a probability chart.

    Args:
        image: RGB image as a numpy uint8 array (any size; resized to 64x64).
        model_choice: Display name of the model to use ("Cassava Model 🍃", ...).

    Returns:
        Tuple of (diagnosis label string, confidence string, PIL image of the
        probability bar chart) — matches the three Gradio outputs.
    """
    label_map = {'cassava-healthy': 0, 'cassava-not-healthy:bacteria blight': 1}
    inverse_map = {v: k for k, v in label_map.items()}

    # Preprocess: force RGB, resize to the 64x64 input the model expects,
    # scale pixels to [0, 1], and add the leading batch axis.
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = image.resize((64, 64))
    image_arr = img_to_array(image)
    image_arr /= 255
    image_arr = image_arr[np.newaxis, :]

    model = _load_model_cached(model_choice)
    # Single sigmoid output: proba is P(not healthy); threshold at 0.5.
    proba = model.predict(image_arr)
    label = (proba > 0.5).squeeze().astype(int)
    result = {
        "label": inverse_map.get(int(label)),
        "probability": float(proba.squeeze())
    }

    # Build the probability bar chart via the figure object (not the global
    # pyplot state) so concurrent requests cannot interleave drawing calls.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(['Healthy 🌿', 'Not Healthy 🍂'],
           [1 - result['probability'], result['probability']],
           color=['#2ecc71', '#e74c3c'])
    ax.set_ylim(0, 1)
    ax.set_ylabel('Probability')
    ax.set_title('Plant Health Prediction 🔍', fontsize=16, fontweight='bold')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.tight_layout()

    # Render the figure to an in-memory PNG, then close it — pyplot keeps
    # figures alive otherwise, leaking memory on every request.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    plot_image = Image.open(buf)

    return result["label"], f"{result['probability']:.2%} of illness(bacteria blight)", plot_image
# Custom CSS for styling
# NOTE: the #component-N selectors target Gradio components by their creation
# order inside the Blocks context below — reordering components there would
# silently break these rules.
custom_css = """
#component-0 {
max-width: 730px;
margin: auto;
padding: 1.5rem;
border-radius: 10px;
background: linear-gradient(135deg, #f6d365 0%, #fda085 100%);
box-shadow: 0 10px 20px rgba(0,0,0,0.19), 0 6px 6px rgba(0,0,0,0.23);
}
#component-1 {
border-radius: 10px;
overflow: hidden;
}
#component-5 {
border-radius: 10px;
overflow: hidden;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.label {
font-size: 18px !important;
color: #2c3e50;
font-weight: bold;
}
.output-class {
font-size: 24px !important;
color: #2980b9;
font-weight: bold;
}
.output-prob {
font-size: 20px !important;
color: #16a085;
}
"""
# Gradio interface
# Layout: image-in / image-out row, model dropdown, then a row with the
# detect button and the two text outputs. Component creation order matters:
# the #component-N rules in custom_css are bound to it.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🌱 Crop Diseases Detector 🕵️‍♂️")
    gr.Markdown("Upload an image of a cassava plant and let's check its health!")
    with gr.Row():
        # type="numpy" hands inference() a uint8 array; type="pil" lets it
        # return the matplotlib chart as a PIL image directly.
        input_image = gr.Image(type="numpy", label="📸 Upload or Capture Image")
        output_image = gr.Image(type="pil", label="🖼️ Health Prediction Visualization")
    # Only the cassava model is exposed; inference() has a (currently unused)
    # fallback branch for other choices.
    model_choice = gr.Dropdown(["Cassava Model 🍃"], label="🤖 Select Model", value="Cassava Model 🍃")
    with gr.Row():
        detect_btn = gr.Button("🔍 Detect Plant Health", variant="primary")
    output_label = gr.Textbox(label="🏷️ Diagnosis")
    output_confidence = gr.Textbox(label="📊 Confidence")
    # Wire the button to inference(); output order must match its return tuple
    # (label, confidence string, chart image).
    detect_btn.click(
        inference,
        inputs=[input_image, model_choice],
        outputs=[output_label, output_confidence, output_image]
    )
    gr.Markdown("## How to use:")
    gr.Markdown("1. 📤 Upload an image or 📸 take a picture of a cassava plant")
    gr.Markdown("2. 🤖 Select the model you want to use")
    gr.Markdown("3. 🔍 Click 'Detect Plant Health' to get the results")
    gr.Markdown("4. 📊 View the diagnosis, confidence score, and health prediction chart")

# debug=True surfaces tracebacks in the UI/console — intended for development.
demo.launch(debug=True)