import streamlit as st
import numpy as np
import json
import os
from PIL import Image
import tensorflow as tf

# ── Page config ─────────────────────────────
st.set_page_config(
    page_title="LeafScan — Plant Disease Detector",
    page_icon="🌿",
    layout="centered",
)

# ── Load model ─────────────────────────────
MODEL_PATH = "plant_model.keras"
INDICES_PATH = "class_indices.json"
# Side length used when the model's spatial input dimensions are unknown/dynamic.
FALLBACK_INPUT_SIZE = 160


class _ModelLoadError(Exception):
    """Signal a model/class-index load failure from the cached loader.

    Streamlit's ``st.cache_resource`` only stores results of calls that return
    normally, so raising (instead of returning an error tuple) keeps a failed
    load from being remembered after the files are fixed.
    """


def _debug_enabled():
    """Return True when verbose diagnostics were requested via LEAFSCAN_DEBUG=1."""
    return os.environ.get("LEAFSCAN_DEBUG") == "1"


def _debug_suffix():
    """Return a working-directory listing for error messages, in debug mode only.

    Showing the server's files to every visitor leaks deployment details, so
    the listing is opt-in and only computed when an error is being reported.
    """
    if not _debug_enabled():
        return ""
    current_dir = os.getcwd()
    try:
        files = os.listdir(current_dir)
    except OSError as e:
        files = f"<unreadable: {e}>"
    return f" Files in {current_dir}: {files}"


def _parse_class_indices(raw):
    """Normalise the parsed class-index JSON to ``{int index: class name}``.

    Two layouts are accepted: index -> name (``{"0": "Apple___scab", ...}``)
    and name -> index (``{"Apple___scab": 0, ...}``; presumably as dumped from
    a Keras directory iterator's ``class_indices`` — confirm).

    Raises:
        _ModelLoadError: if *raw* is not a non-empty object, or an index is
            not an integer (which would otherwise make every lookup "Unknown").
    """
    if not isinstance(raw, dict) or not raw:
        raise _ModelLoadError(
            f"class_indices.json must be a non-empty JSON object: {INDICES_PATH}"
        )
    try:
        if all(str(k).isdigit() for k in raw):
            return {int(k): v for k, v in raw.items()}
        return {int(v): k for k, v in raw.items()}
    except (TypeError, ValueError) as e:
        raise _ModelLoadError(f"Invalid class index in {INDICES_PATH}: {e}") from e


@st.cache_resource
def _load_model_cached():
    """Load the Keras model and class-index map; cached by ``st.cache_resource``.

    Returns:
        ``(model, class_indices, input_shape)``.

    Raises:
        _ModelLoadError: on a missing file or any load/parse failure. Because
            it raises, the failure is not cached.
    """
    if not os.path.exists(MODEL_PATH):
        raise _ModelLoadError(f"Model not found: {MODEL_PATH}.{_debug_suffix()}")
    if not os.path.exists(INDICES_PATH):
        raise _ModelLoadError(
            f"class_indices.json not found: {INDICES_PATH}.{_debug_suffix()}"
        )

    try:
        # compile=False: only inference is done, so training config is not needed.
        model = tf.keras.models.load_model(MODEL_PATH, compile=False)
    except Exception as e:  # Keras may raise many different exception types here.
        raise _ModelLoadError(f"Failed to load model: {e}.{_debug_suffix()}") from e

    try:
        with open(INDICES_PATH, "r", encoding="utf-8") as f:
            raw = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError is a ValueError.
        raise _ModelLoadError(f"Failed to load class indices: {e}") from e

    class_indices = _parse_class_indices(raw)
    return model, class_indices, model.input_shape


def load_model():
    """Return ``(model, class_indices, input_shape, error)`` for the UI.

    On success ``error`` is None. On failure the first three items are None
    and ``error`` is a human-readable message. Only successful loads are
    cached, so fixing a missing/broken file takes effect on the next rerun.
    """
    try:
        model, class_indices, input_shape = _load_model_cached()
    except _ModelLoadError as e:
        return None, None, None, str(e)
    return model, class_indices, input_shape, None


# ── Prediction ─────────────────────────────
def _target_hw(input_shape):
    """Return the ``(height, width)`` to resize images to.

    Reads dims 1 and 2 of ``model.input_shape`` as height and width
    (assumes channels-last ``(batch, H, W, C)`` — TODO confirm for this model).
    Missing or None (dynamic) dimensions fall back to FALLBACK_INPUT_SIZE.
    """
    if input_shape and len(input_shape) >= 3:
        height, width = input_shape[1], input_shape[2]
    else:
        height = width = None
    return height or FALLBACK_INPUT_SIZE, width or FALLBACK_INPUT_SIZE


def predict(model, class_indices, img, input_shape, top_k=5):
    """Classify a PIL image and return the best label plus a top-k ranking.

    Args:
        model: Keras model; row 0 of ``model.predict`` is treated as per-class scores.
        class_indices: ``{int index: class name}`` mapping.
        img: PIL image in any mode (converted to RGB; the caller's image is not modified).
        input_shape: ``model.input_shape``, used to pick the resize target.
        top_k: number of ranked predictions to return (default 5).

    Returns:
        ``(best_name, best_score, [(name, score), ...])`` where each score is the
        model output times 100 (a percentage if the model ends in softmax).
        Indices missing from *class_indices* are labelled "Unknown".
    """
    height, width = _target_hw(input_shape)
    # PIL's resize takes (width, height) — the reverse of Keras' (height, width).
    img = img.convert("RGB").resize((width, height))
    # NOTE(review): [0, 1] scaling must match the training preprocessing — confirm.
    arr = np.array(img, dtype=np.float32) / 255.0
    preds = model.predict(arr[np.newaxis, ...], verbose=0)[0]

    ranked = np.argsort(preds)[::-1][:top_k]
    top = [
        (class_indices.get(int(i), "Unknown"), float(preds[i]) * 100)
        for i in ranked
    ]
    best_name, best_score = top[0]
    return best_name, best_score, top


# ── UI ─────────────────────────────
st.title("🌿 LeafScan - Plant Disease Detector")

# Debug-only: show the XSRF setting when LEAFSCAN_DEBUG=1. Printing server
# security configuration to every visitor is an information leak.
if os.environ.get("LEAFSCAN_DEBUG") == "1":
    st.write("XSRF Protection:", st.get_option("server.enableXsrfProtection"))

model, class_indices, input_shape, error = load_model()
if error:
    st.error(error)
    st.stop()  # Nothing below can run without a model; end this rerun.

st.success(f"Model loaded — {len(class_indices)} classes")
st.info(f"Model input shape: {input_shape}")

uploaded = st.file_uploader("Upload a leaf image", type=["jpg", "png", "jpeg"])

if uploaded:
    try:
        img = Image.open(uploaded)
        # Image.open is lazy; decode now so corrupt/non-image uploads fail here
        # with a readable message instead of a traceback later.
        img.load()
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded image: {e}")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate use_column_width in favour
    # of use_container_width — switch once the installed version supports it.
    st.image(img, caption="Uploaded Image", use_column_width=True)

    if st.button("Analyze"):
        try:
            with st.spinner("Predicting..."):
                class_name, conf, top5 = predict(model, class_indices, img, input_shape)
        except Exception as e:  # UI boundary: surface any model failure to the user.
            st.error(f"Prediction failed: {e}")
        else:
            # Labels are expected as "Plant___Disease". partition() tolerates
            # extra "___" separators, which made the old 2-way unpack raise.
            label = str(class_name)
            plant, sep, disease = label.partition("___")
            if not sep:
                plant, disease = "Unknown", label

            st.subheader("Result")
            st.write("🌱 Plant:", plant.replace("_", " "))
            st.write("🦠 Disease:", disease.replace("_", " "))
            st.write(f"🎯 Confidence: {conf:.2f}%")

            st.subheader("Top Predictions")
            for name, p in top5:
                st.write(f"{name}: {p:.2f}%")