| import streamlit as st |
| import numpy as np |
| import json |
| import os |
| from PIL import Image |
| import tensorflow as tf |
|
|
| |
# Configure the browser tab (title + favicon emoji) and the page layout.
st.set_page_config(
    page_title="LeafScan — Plant Disease Detector",
    page_icon="🌿",
    layout="centered",
)
|
|
| |
def _dir_listing():
    """Describe the working directory's contents, for not-found diagnostics."""
    current_dir = os.getcwd()
    try:
        files = os.listdir(current_dir)
    except OSError as e:
        files = f"<unreadable: {e}>"
    return f"Files in {current_dir}: {files}"


def _resolve_path(filename):
    """Return the first existing location of *filename*, else *filename* itself.

    The working directory is checked first (the original behaviour), then the
    directory containing this script, so the app also finds its artifacts
    when it is launched from a different directory.
    """
    candidates = [filename]
    script_file = globals().get("__file__")
    if script_file:
        script_dir = os.path.dirname(os.path.abspath(script_file))
        candidates.append(os.path.join(script_dir, filename))
    for path in candidates:
        if os.path.exists(path):
            return path
    return filename


def _parse_class_indices(raw):
    """Normalise loaded class-index JSON into an ``{int index: class name}`` dict.

    Accepted layouts:
      * ``{"0": "name", ...}``  -- index -> name (keys are decimal strings)
      * ``{"name": 0, ...}``    -- name -> index (values int or decimal str)
      * ``["name0", "name1"]``  -- names ordered by index

    Raises:
        ValueError / TypeError: if the data matches none of these layouts.
    """
    if isinstance(raw, list):
        return dict(enumerate(raw))
    if not isinstance(raw, dict):
        raise ValueError(f"unsupported class index format: {type(raw).__name__}")
    if all(str(k).isdecimal() for k in raw):
        return {int(k): v for k, v in raw.items()}
    # name -> index layout: invert, coercing indices to int so lookups by the
    # integer argmax position succeed even when indices were stored as strings.
    return {int(v): k for k, v in raw.items()}


@st.cache_resource
def load_model():
    """Load the Keras model and its class-index mapping from disk.

    The result is cached by ``st.cache_resource``. Note that failure tuples
    are cached too, so a caller that sees an error should call
    ``load_model.clear()`` to allow a retry on the next rerun.

    Returns:
        ``(model, class_indices, input_shape, error)``. On success ``error``
        is ``None`` and ``class_indices`` maps int index -> class name; on
        failure the first three items are ``None`` and ``error`` is a
        human-readable message.
    """
    model_path = _resolve_path("plant_model.keras")
    indices_path = _resolve_path("class_indices.json")

    # NOTE(review): the directory listing is shown to end users; consider
    # logging it instead once deployment issues are resolved.
    if not os.path.exists(model_path):
        return None, None, None, f"Model not found: {model_path}. {_dir_listing()}"

    if not os.path.exists(indices_path):
        return None, None, None, f"class_indices.json not found: {indices_path}. {_dir_listing()}"

    try:
        model = tf.keras.models.load_model(model_path, compile=False)
    except Exception as e:
        return None, None, None, f"Failed to load model: {e}. {_dir_listing()}"

    try:
        with open(indices_path, "r", encoding="utf-8") as f:
            class_indices = _parse_class_indices(json.load(f))
    except Exception as e:
        return None, None, None, f"Failed to load class indices: {e}"

    return model, class_indices, model.input_shape, None
|
|
|
|
| |
def predict(model, class_indices, img, input_shape, top_k=5):
    """Classify an image and return the best class plus a ranked top-k list.

    Args:
        model: Loaded model; only ``model.predict(batch, verbose=0)`` is used.
        class_indices: Mapping from integer class index to class name.
        img: PIL-style image; converted to RGB (or "L" for 1-channel models)
            and resized to the model's spatial size.
        input_shape: The model's input shape, read as
            ``(batch, height, width[, channels])``. Falsy input, a shape
            shorter than 3, or ``None`` spatial dims fall back to 160x160.
        top_k: Number of ranked predictions to return (default 5).

    Returns:
        ``(best_name, best_confidence_pct, top)``, where ``top`` is a list of
        ``(name, confidence_pct)`` tuples in descending score order. Indices
        missing from *class_indices* are reported as ``"Unknown"``.

    Raises:
        ValueError: if no scores are available (empty output or ``top_k <= 0``).
    """
    height, width, channels = 160, 160, 3
    if input_shape and len(input_shape) >= 3:
        # Assumes channels-last (batch, height, width, channels) -- TODO confirm
        # for this model. A None dim means "any size"; keep the default then.
        height = input_shape[1] or height
        width = input_shape[2] or width
        if len(input_shape) >= 4 and input_shape[3] == 1:
            channels = 1

    # PIL's resize() takes (width, height) -- the reverse of the Keras order.
    resized = img.convert("L" if channels == 1 else "RGB").resize((width, height))
    # Scaling to [0, 1] assumes the model was trained on such inputs -- TODO confirm.
    arr = np.asarray(resized, dtype=np.float32) / 255.0
    if arr.ndim == 2:
        arr = arr[..., np.newaxis]  # "L" images have no channel axis
    batch = np.expand_dims(arr, axis=0)

    scores = np.asarray(model.predict(batch, verbose=0)[0])
    ranked = np.argsort(scores)[::-1][:top_k]
    # Reported as score * 100; presumably the model ends in softmax -- verify.
    top = [(class_indices.get(int(i), "Unknown"), float(scores[i]) * 100) for i in ranked]
    if not top:
        raise ValueError("model returned no class scores")

    return top[0][0], top[0][1], top
|
|
|
|
| |
st.title("🌿 LeafScan - Plant Disease Detector")

# Load the (cached) model artifacts; stop rendering the page if anything failed.
model, class_indices, input_shape, error = load_model()

if error:
    # load_model is wrapped in st.cache_resource, which also caches this
    # failure tuple; clear it so a fixed deployment is retried on the next
    # rerun instead of serving the stale error until the server restarts.
    load_model.clear()
    st.error(error)
    st.stop()

st.success(f"Model loaded — {len(class_indices)} classes")
st.info(f"Model input shape: {input_shape}")
|
|
uploaded = st.file_uploader("Upload a leaf image", type=["jpg", "png", "jpeg"])

if uploaded:
    try:
        img = Image.open(uploaded)
        img.load()  # decode now so corrupt/truncated uploads fail here, not later
    except Exception as e:  # PIL raises several exception types for bad input
        st.error(f"Could not read image: {e}")
        st.stop()

    # NOTE(review): use_column_width is deprecated in newer Streamlit releases
    # in favour of use_container_width; kept for older-version compatibility.
    st.image(img, caption="Uploaded Image", use_column_width=True)

    if st.button("Analyze"):
        try:
            with st.spinner("Predicting..."):
                class_name, conf, top5 = predict(model, class_indices, img, input_shape)

            # Class names are expected as "Plant___Disease". Split on the first
            # separator only, so a name containing "___" twice cannot raise.
            plant, sep, disease = class_name.partition("___")
            if not sep:
                plant, disease = "Unknown", class_name

            st.subheader("Result")
            st.write("🌱 Plant:", plant.replace("_", " "))
            st.write("🦠 Disease:", disease.replace("_", " "))
            st.write(f"🎯 Confidence: {conf:.2f}%")

            st.subheader("Top Predictions")
            for name, p in top5:
                st.write(f"{name}: {p:.2f}%")
        except Exception as e:
            st.error(f"Prediction failed: {e}")
|
|