# shashankbodapati's picture
# Update app.py
# bb9b190
import gradio as gr
import tensorflow as tf
from PIL import Image
import numpy as np
from PIL import Image
# Relative path: resolved against the working directory the app is started from.
_MODEL_FILE = "my_final_model.h5"

# Deserialize the network a single time at import so that each prediction
# request reuses the same in-memory model instead of reloading it.
MODEL = tf.keras.models.load_model(_MODEL_FILE)

# Label strings; classify_image pairs CLASS_NAMES[i] with the i-th model output.
CLASS_NAMES = [
    "Potato_Early_blight",
    "Potato_Late_blight",
    "Potato_healthy",
]
# Static reference text returned alongside every prediction.  It lives at
# module level so the literal is built once and its lines stay flush-left
# (leading whitespace inside the literal would become part of the Markdown).
# The link target carries an explicit scheme; without one Markdown renderers
# treat "www.euroblight.net" as a path relative to the current page.
# NOTE(review): the "..." line looks like a placeholder for further details —
# confirm before release.
BLIGHT_INFO = """
### Late Blight in Potatoes
**Pathogen:** Fungus
**Hosts:** Potatoes
...
For more information and updates on potato blight strains, visit [EuroBlight](https://www.euroblight.net).
"""


def classify_image(image_input):
    """Score a potato-leaf image against ``CLASS_NAMES``.

    Args:
        image_input: Image pixels as a NumPy array (it is cast to ``uint8``
            and handed to Pillow), or ``None`` when no image was supplied.

    Returns:
        A 2-tuple ``(scores, markdown)``.  ``scores`` maps each entry of
        ``CLASS_NAMES`` to the corresponding model output as a ``float``;
        ``markdown`` is the late-blight reference text, which is the same
        for every prediction.  When ``image_input`` is ``None`` the model is
        not run and the result is ``(None, <prompt asking for an image>)``.
    """
    # An empty Gradio image input reaches the callback as None; answer with a
    # prompt instead of failing on ``None.astype``.
    if image_input is None:
        return None, "Please upload a potato leaf image to get a prediction."

    # Let Pillow infer the mode from the array's shape, then normalise to
    # three channels so grayscale or RGBA arrays work as well as RGB ones.
    pil_image = Image.fromarray(image_input.astype('uint8')).convert('RGB')

    # 256x256 is presumably the resolution the model was trained on — TODO confirm.
    resized = pil_image.resize((256, 256))

    # Scale pixel values to [0, 1] and add the leading batch axis.
    # NOTE(review): if the saved model already contains its own rescaling
    # layer this would scale twice — confirm against the training pipeline.
    batch = (np.array(resized) / 255.0).reshape((1, 256, 256, 3))

    predictions = MODEL.predict(batch)

    # predictions[0] holds the outputs for the single image in the batch.
    scores = {
        name: float(score)
        for name, score in zip(CLASS_NAMES, predictions[0])
    }
    return scores, BLIGHT_INFO
def main():
    """Assemble the Gradio UI around ``classify_image`` and start serving it."""
    # Output components are listed in the order classify_image returns its
    # values: the per-class scores first, the Markdown text second.
    demo = gr.Interface(
        fn=classify_image,
        inputs=gr.Image(label="Upload a Potato Leaf Image"),
        outputs=[
            gr.Label(num_top_classes=3, label="Predictions"),
            gr.Markdown(),
        ],
        title="Potato Leaf Disease Detection",
        description=(
            "Upload an image of a potato leaf, and the model will predict "
            "if it is healthy or has early blight or late blight."
        ),
    )

    # share=True additionally requests a publicly reachable link from Gradio.
    demo.launch(share=True)
# Start the app only when this file is executed as a script; merely importing
# the module does not call main().
if __name__ == "__main__":
    main()