LeafScan / app.py
A7md47's picture
Update app.py
968b547 verified
Raw
History Blame Contribute Delete
3.79 kB
import streamlit as st
import numpy as np
import json
import os
from PIL import Image
import tensorflow as tf
# ── Page config ─────────────────────────────
# Browser-tab title/icon and page layout for the app.
st.set_page_config(
    page_title="LeafScan — Plant Disease Detector",
    page_icon="🌿",
    layout="centered",
)
# ── Load model ─────────────────────────────
def _cwd_listing():
    """Describe the working directory's contents for load-failure messages.

    Listing errors are reported inline instead of raised, so building an
    error message can never itself crash the app.
    NOTE(review): this text is shown to end users via st.error; consider
    dropping it once the deployment paths are confirmed.
    """
    current_dir = os.getcwd()
    try:
        files = os.listdir(current_dir)
    except OSError as e:
        files = f"<unreadable: {e}>"
    return f"Files in {current_dir}: {files}"


def _parse_class_indices(raw):
    """Normalize decoded class-index JSON into ``{int index: class name}``.

    Accepted layouts:
      * ``{"0": "Plant___Disease", ...}`` - index -> name (digit-string keys)
      * ``{"Plant___Disease": 0, ...}``  - name -> index (e.g. a Keras
        ``class_indices`` dump); inverted here
      * ``["Plant___Disease", ...]``     - names listed in index order

    Raises:
        ValueError, TypeError: if ``raw`` matches none of the layouts above or
            an index cannot be converted to ``int``.
    """
    if isinstance(raw, list):
        return dict(enumerate(raw))
    if not isinstance(raw, dict):
        raise ValueError(
            f"expected a JSON object or array, got {type(raw).__name__}"
        )
    if all(str(k).isdigit() for k in raw):
        return {int(k): v for k, v in raw.items()}
    # name -> index: invert, coercing indices to int so they match the
    # integer class positions that predict() looks up.
    return {int(v): k for k, v in raw.items()}


@st.cache_resource
def load_model():
    """Load the Keras model and its class-index mapping from the working dir.

    Wrapped in ``st.cache_resource``, so the returned tuple (including an
    error tuple) is reused across reruns. Expected failures are reported, not
    raised, as a 4-tuple ``(model, class_indices, input_shape, error)``:

    * success: ``(model, {int index: class name}, input_shape, None)``,
      where ``input_shape`` is ``model.input_shape`` or ``None`` if the
      model does not expose one;
    * failure: ``(None, None, None, "<human-readable message>")``.
    """
    model_path = "plant_model.keras"
    indices_path = "class_indices.json"

    if not os.path.exists(model_path):
        return None, None, None, f"Model not found: {model_path}. {_cwd_listing()}"
    if not os.path.exists(indices_path):
        return None, None, None, f"class_indices.json not found: {indices_path}. {_cwd_listing()}"

    try:
        model = tf.keras.models.load_model(model_path, compile=False)
    except Exception as e:  # broad on purpose: any load failure becomes a UI message
        return None, None, None, f"Failed to load model: {e}. {_cwd_listing()}"

    try:
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(indices_path, "r", encoding="utf-8") as f:
            raw = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        return None, None, None, f"Failed to load class indices: {e}"

    try:
        class_indices = _parse_class_indices(raw)
    except (TypeError, ValueError) as e:
        return None, None, None, f"Invalid class indices format: {e}"
    if not class_indices:
        return None, None, None, f"No classes found in {indices_path}"

    # predict() falls back to a default size when input_shape is None.
    input_shape = getattr(model, "input_shape", None)
    return model, class_indices, input_shape, None
# ── Prediction ─────────────────────────────
def predict(model, class_indices, img, input_shape, top_k=5):
    """Classify a leaf image and return its top-ranked predictions.

    Args:
        model: object exposing ``predict(batch, verbose=0)`` that returns one
            row of class scores per batch item (a Keras model here).
        class_indices: mapping from integer class index to class name.
        img: PIL image; converted to RGB and resized to the model input size.
        input_shape: the model's input shape, laid out as
            ``(batch, height, width, channels)``. When it is falsy, shorter
            than 3, or has ``None`` spatial dims, 160 px is used instead.
        top_k: number of ranked predictions to return (at least 1).

    Returns:
        ``(best_name, best_confidence_pct, ranked)``. ``ranked`` is a list of
        ``(class_name, confidence_pct)`` tuples, highest score first. Indices
        missing from ``class_indices`` are reported as "Unknown".
    """
    fallback = 160  # px, used when the model does not declare a spatial size
    height = width = fallback
    if input_shape and len(input_shape) >= 3:
        # (batch, height, width, ...); spatial dims are None for models that
        # accept variable-size images.
        height = input_shape[1] or fallback
        width = input_shape[2] or fallback

    # PIL's resize() takes (width, height), the reverse of the Keras order.
    img = img.convert("RGB").resize((int(width), int(height)))
    arr = np.array(img, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    arr = np.expand_dims(arr, axis=0)  # add the batch dimension

    preds = model.predict(arr, verbose=0)[0]
    ranked_idx = np.argsort(preds)[::-1][: max(int(top_k), 1)]
    ranked = [
        (class_indices.get(int(i), "Unknown"), float(preds[i]) * 100)
        for i in ranked_idx
    ]
    return ranked[0][0], ranked[0][1], ranked
# ── UI ─────────────────────────────
def _split_label(class_name):
    """Split a ``Plant___Disease`` class label into display-ready parts.

    Labels without the ``___`` separator are shown as a disease of an
    "Unknown" plant. Only the first separator splits, so labels containing
    extra ``___`` sequences keep them in the disease part instead of raising.
    Underscores are rendered as spaces.
    """
    label = str(class_name)
    if "___" in label:
        plant, _, disease = label.partition("___")
    else:
        plant, disease = "Unknown", label
    return plant.replace("_", " "), disease.replace("_", " ")


st.title("🌿 LeafScan - Plant Disease Detector")

model, class_indices, input_shape, error = load_model()
if error:
    # st.cache_resource also caches this error tuple; clear it so the next
    # rerun retries the load instead of replaying the failure until restart.
    load_model.clear()
    st.error(error)
    st.stop()

st.success(f"Model loaded — {len(class_indices)} classes")
st.info(f"Model input shape: {input_shape}")

uploaded = st.file_uploader("Upload a leaf image", type=["jpg", "png", "jpeg"])
if uploaded is not None:
    try:
        img = Image.open(uploaded)
        img.load()  # decode now so corrupt/unsupported files fail here, not later
    except Exception as e:  # UI boundary: report any unreadable upload to the user
        st.error(f"Could not read image: {e}")
        st.stop()
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the pinned version allows.
    st.image(img, caption="Uploaded Image", use_column_width=True)
    if st.button("Analyze"):
        try:
            with st.spinner("Predicting..."):
                class_name, conf, top5 = predict(model, class_indices, img, input_shape)
            plant, disease = _split_label(class_name)
            st.subheader("Result")
            st.write("🌱 Plant:", plant)
            st.write("🦠 Disease:", disease)
            st.write(f"🎯 Confidence: {conf:.2f}%")
            st.subheader("Top Predictions")
            for name, p in top5:
                st.write(f"{name}: {p:.2f}%")
        except Exception as e:  # UI boundary: show failures instead of crashing
            st.error(f"Prediction failed: {e}")