"""Streamlit app: upload a rose-leaf image and classify it with a trained model.

The uploaded image is written to a temporary file under ``temp/`` (the
``utils`` helpers take a file path), displayed, and — when the user clicks
"Detect Disease" — passed to ``utils.predict`` / ``utils.get_class_probabilities``
together with the model at ``models/rose_model.h5``. Optional training
artifacts (``models/metrics.json``, ``models/confusion_matrix.json``) are shown
when present. The temporary file is removed at the end of every script run.
"""
import contextlib
import json
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import streamlit as st
from PIL import Image

from utils import get_class_probabilities, predict

MODEL_DIR = "models"
MODEL_PATH = "models/rose_model.h5"
METRICS_PATH = "models/metrics.json"
CONFUSION_MATRIX_PATH = "models/confusion_matrix.json"
TEMP_DIR = "temp"


def _save_upload(uploaded_file) -> str:
    """Write the uploaded file's bytes to a new file in TEMP_DIR and return its path.

    The file is created with ``delete=False`` so it survives the ``with`` block;
    the caller is responsible for removing it.
    """
    # Keep the upload's own extension (falls back to .jpg as before) so the
    # suffix on disk matches the actual image format.
    suffix = os.path.splitext(uploaded_file.name or "")[1].lower() or ".jpg"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=TEMP_DIR) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        return tmp_file.name


def _remove_quietly(path: str) -> None:
    """Delete *path*; best-effort, so OS errors (already gone, locked) are ignored."""
    with contextlib.suppress(OSError):
        os.unlink(path)


def _plot_probabilities(probabilities: dict) -> None:
    """Render a bar chart of class name -> probability and close the figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.bar(list(probabilities.keys()), list(probabilities.values()))
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_ylabel("Probability")
        ax.set_title("Class Probabilities")
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # Streamlit reruns the whole script on every interaction; unclosed
        # figures would otherwise accumulate in pyplot's global registry.
        plt.close(fig)


def _show_metrics(path: str) -> None:
    """Show accuracy/precision/recall from a JSON file, if it exists (missing keys show 0%)."""
    if not os.path.exists(path):
        return
    with open(path, "r", encoding="utf-8") as f:
        metrics = json.load(f)

    st.subheader("Model Performance Metrics")
    shown = (("Accuracy", "accuracy"), ("Precision", "precision"), ("Recall", "recall"))
    for column, (title, key) in zip(st.columns(3), shown):
        with column:
            st.metric(title, f"{metrics.get(key, 0):.2%}")


def _show_confusion_matrix(path: str) -> None:
    """Show a heatmap of the confusion matrix stored as a JSON nested list, if it exists."""
    if not os.path.exists(path):
        return
    with open(path, "r", encoding="utf-8") as f:
        cm = np.array(json.load(f))

    # fmt="d" only formats integers; a normalized (float) matrix would raise.
    fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt=fmt, cmap="Blues", ax=ax)
        ax.set_title("Confusion Matrix")
        ax.set_ylabel("True Label")
        ax.set_xlabel("Predicted Label")
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _run_detection(image_path: str) -> None:
    """Classify the image at *image_path* and render the prediction and model diagnostics.

    Shows an error and returns early if the model file is missing; any exception
    raised while predicting or plotting is reported in the UI instead of crashing.
    """
    if not os.path.exists(MODEL_PATH):
        st.error("Model not found. Please train the model first.")
        return

    try:
        label, confidence = predict(MODEL_PATH, image_path)

        # Customize the output based on the prediction
        message = f"**Prediction**: {label} ({confidence*100:.2f}% confidence)"
        if "Healthy" in label:
            st.success(message)
        else:
            st.warning(message)
            st.info("⚠️ This leaf appears to be affected by a disease. Please take appropriate measures.")

        st.subheader("Probability Distribution")
        # NOTE(review): predict() and get_class_probabilities() each receive the
        # model path; if both load the model, a single combined call would halve
        # the work — confirm against utils.
        probabilities = get_class_probabilities(MODEL_PATH, image_path)
        _plot_probabilities(probabilities)

        _show_metrics(METRICS_PATH)
        _show_confusion_matrix(CONFUSION_MATRIX_PATH)
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        st.info("Please make sure the model is trained and available in the models directory.")


st.set_page_config(page_title="Rose Disease Detection", layout="centered")
st.title("🌹 Rose Disease Detection")
st.write("Upload a rose leaf image to detect diseases")

# Add description of possible classes
st.markdown("""
### Possible Classifications:
- **Healthy Leaf Rose**: Healthy rose leaves
- **Rose Rust**: Rose leaves affected by rust disease
- **Rose Sawfly/Rose Slug**: Rose leaves affected by sawfly or slug damage
""")

# Create necessary directories
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)

uploaded_file = st.file_uploader("Upload a Rose Leaf Image", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    temp_path = None
    try:
        # Save the upload to disk because the prediction helpers take a path.
        temp_path = _save_upload(uploaded_file)

        # Open inside `with` so the handle is released before the file is
        # deleted (an open handle makes os.unlink fail on Windows).
        with Image.open(temp_path) as image:
            st.image(image, caption="Uploaded Image", use_container_width=True)

        if st.button("Detect Disease"):
            _run_detection(temp_path)
    except Exception as e:
        st.error(f"Error processing the uploaded file: {e}")
        st.info("Please try uploading the image again.")
    finally:
        # Runs on every script run (button clicked or not), so temp/ does not
        # fill up with one file per Streamlit rerun.
        if temp_path is not None:
            _remove_quietly(temp_path)