# rice / app.py (author: hardin009)
# Last commit: d1ea450 "Update app.py" (verified)
from pathlib import Path

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# Load the trained model from the directory that contains this script, so the
# app works no matter which working directory it is launched from (a bare
# relative filename would be resolved against the current working directory).
MODEL_PATH = Path(__file__).resolve().parent / "rice_stage_model.keras"
model = tf.keras.models.load_model(str(MODEL_PATH))

# Class labels, indexed by the position of each score in the model's output.
# NOTE(review): assumes this order matches the label order used in training
# (e.g. alphabetical folder order) — confirm against the training pipeline.
classes = ["flowering", "germination", "noise", "tillering"]

# Minimum top-class score required before a stage is reported; predictions
# scoring below this are rejected to reduce false positives.
CONFIDENCE_THRESHOLD = 0.6
def predict_stage(image):
    """Classify the growth stage of a rice crop photo.

    Args:
        image: A PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when nothing has been uploaded.

    Returns:
        A user-facing message string: the predicted stage with its confidence,
        or a warning when no image was given, the top class is ``"noise"``,
        or the top score is below ``CONFIDENCE_THRESHOLD``.
    """
    if image is None:
        return "⚠️ Please upload an image first."

    # Convert to RGB to drop alpha/palette channels (e.g. from PNG uploads),
    # so the resulting array always has exactly 3 channels.
    image = image.convert("RGB")
    # Resize to the model's expected spatial input size.
    # NOTE(review): assumes the model was trained on 224x224 inputs — confirm.
    image = image.resize((224, 224))

    # Scale pixels to [0, 1] as float32 (Keras' default float dtype) and add a
    # leading batch axis of size 1.
    img = np.asarray(image, dtype=np.float32) / 255.0
    img = np.expand_dims(img, axis=0)

    # verbose=0 stops Keras from writing a progress bar to the server log on
    # every request; [0] takes the scores for the single image in the batch.
    predictions = model.predict(img, verbose=0)[0]

    predicted_index = int(np.argmax(predictions))
    confidence = float(predictions[predicted_index])
    predicted_label = classes[predicted_index]

    # "noise" acts as a reject class. It is checked before the confidence gate
    # so that non-crop images get the more specific message.
    if predicted_label == "noise":
        return "⚠️ Invalid image. Please upload a clear rice crop image."
    # NOTE(review): assumes the model outputs probabilities (e.g. a softmax
    # head); against raw logits this threshold would not be meaningful.
    if confidence < CONFIDENCE_THRESHOLD:
        return "⚠️ Model uncertain. Please upload a clearer crop photo."

    return f"🌾 Predicted Stage: {predicted_label.capitalize()} | Confidence: {confidence:.2f}"
# Set up the Gradio interface: one PIL image in, one text message out.
interface = gr.Interface(
    fn=predict_stage,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Rice Crop Growth Stage Classifier",
    description="Upload a picture of the rice crop from the field to instantly detect its growth stage.",
    # `flagging_mode` is the Gradio 5 replacement for the deprecated
    # `allow_flagging` argument (it does not exist in Gradio 4);
    # "never" disables the Flag button.
    flagging_mode="never",
)
if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`). The
    # original note says launching directly avoids Gradio's auto-reloader,
    # which was crashing the app.
    #   server_name "0.0.0.0" -> listen on all network interfaces
    #   server_port 7860      -> fixed port
    #   ssr_mode False        -> server-side rendering disabled
    #   show_error True       -> surface Python errors in the UI
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
        show_error=True,
    )
    interface.launch(**launch_options)