Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from utils import predict, get_class_probabilities | |
| from PIL import Image | |
| import os | |
| import json | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import tempfile | |
# Streamlit front end for the rose-leaf disease classifier.
#
# Flow: the user uploads an image, the image is written to a temporary file
# under temp/, shown back to the user, and -- when "Detect Disease" is
# clicked -- passed to predict() / get_class_probabilities() from utils.
# Saved evaluation artifacts under models/ (metrics.json and
# confusion_matrix.json) are rendered when they exist.


def _remove_quietly(path):
    """Delete *path* as best-effort cleanup, ignoring only filesystem errors."""
    try:
        os.unlink(path)
    except OSError:
        # The file may already be gone or still locked; neither case should
        # surface to the user as an application error.
        pass


def _show_probability_chart(probabilities):
    """Render a bar chart from a mapping of class name -> probability."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.bar(list(probabilities.keys()), list(probabilities.values()))
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_ylabel('Probability')
        ax.set_title('Class Probabilities')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # Figures created through pyplot stay registered until closed, so
        # close explicitly to avoid accumulating one figure per rerun.
        plt.close(fig)


def _show_metrics(metrics_path):
    """Render accuracy/precision/recall from the JSON file at *metrics_path*.

    Keys missing from the JSON object are displayed as 0.
    """
    with open(metrics_path, "r", encoding="utf-8") as f:
        metrics = json.load(f)

    st.subheader("Model Performance Metrics")
    entries = (
        ("Accuracy", "accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
    )
    for column, (title, key) in zip(st.columns(3), entries):
        with column:
            st.metric(title, f"{metrics.get(key, 0):.2%}")


def _show_confusion_matrix(cm_path):
    """Render the confusion matrix stored as JSON at *cm_path* as a heatmap."""
    with open(cm_path, "r", encoding="utf-8") as f:
        cm = np.array(json.load(f))

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_title('Confusion Matrix')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        st.pyplot(fig)
    finally:
        plt.close(fig)


st.set_page_config(page_title="Rose Disease Detection", layout="centered")
st.title("🌹 Rose Disease Detection")
st.write("Upload a rose leaf image to detect diseases")

# Add description of possible classes
st.markdown("""
### Possible Classifications:
- **Healthy Leaf Rose**: Healthy rose leaves
- **Rose Rust**: Rose leaves affected by rust disease
- **Rose Sawfly/Rose Slug**: Rose leaves affected by sawfly or slug damage
""")

# Create necessary directories
os.makedirs('models', exist_ok=True)
os.makedirs('temp', exist_ok=True)

uploaded_file = st.file_uploader("Upload a Rose Leaf Image", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    temp_path = None
    try:
        # Save the uploaded file to a temporary location. delete=False keeps
        # the file on disk after the handle closes so it can be reopened by
        # path below; it is removed in the outer `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg', dir='temp') as tmp_file:
            # Record the path before writing so a failed write is still
            # cleaned up.
            temp_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())

        # Open and display the image; the context manager releases the file
        # handle so the temporary file can be deleted afterwards.
        with Image.open(temp_path) as image:
            st.image(image, caption="Uploaded Image", use_container_width=True)

        if st.button("Detect Disease"):
            model_path = "models/rose_model.h5"
            if not os.path.exists(model_path):
                st.error("Model not found. Please train the model first.")
                st.stop()

            try:
                # Use the temporary file for prediction
                label, confidence = predict(model_path, temp_path)

                # Customize the output based on the prediction
                if "Healthy" in label:
                    st.success(f"**Prediction**: {label} ({confidence*100:.2f}% confidence)")
                else:
                    st.warning(f"**Prediction**: {label} ({confidence*100:.2f}% confidence)")
                    st.info("⚠️ This leaf appears to be affected by a disease. Please take appropriate measures.")

                # Display probability distribution
                st.subheader("Probability Distribution")
                probabilities = get_class_probabilities(model_path, temp_path)
                _show_probability_chart(probabilities)

                # Display metrics if available
                metrics_path = "models/metrics.json"
                if os.path.exists(metrics_path):
                    _show_metrics(metrics_path)

                # Display confusion matrix if available
                cm_path = "models/confusion_matrix.json"
                if os.path.exists(cm_path):
                    _show_confusion_matrix(cm_path)
            except Exception as e:
                st.error(f"Error during prediction: {str(e)}")
                st.info("Please make sure the model is trained and available in the models directory.")
    except Exception as e:
        st.error(f"Error processing the uploaded file: {str(e)}")
        st.info("Please try uploading the image again.")
    finally:
        # Clean up the temporary file on every exit path (button not clicked,
        # st.stop(), success, or failure) so temp/ does not grow per rerun.
        if temp_path is not None:
            _remove_quietly(temp_path)