File size: 4,611 Bytes
b634f72
df61dd1
 
 
 
 
 
 
20f4b67
df61dd1
 
 
 
 
 
 
 
 
 
 
 
 
20f4b67
 
 
 
df61dd1
 
 
20f4b67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df61dd1
20f4b67
 
 
 
 
 
 
 
 
 
df61dd1
20f4b67
 
 
 
 
 
 
 
 
 
 
 
 
 
df61dd1
20f4b67
 
 
 
 
 
 
 
 
 
 
 
 
df61dd1
20f4b67
 
 
 
 
 
 
 
 
 
 
df61dd1
20f4b67
 
 
df61dd1
20f4b67
 
 
 
 
 
b634f72
20f4b67
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
from utils import predict, get_class_probabilities
from PIL import Image
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tempfile

# ---- Page chrome -----------------------------------------------------------
st.set_page_config(layout="centered", page_title="Rose Disease Detection")
st.title("🌹 Rose Disease Detection")
st.write("Upload a rose leaf image to detect diseases")

# Overview of the labels the user may see in a prediction.
class_overview = """
### Possible Classifications:
- **Healthy Leaf Rose**: Healthy rose leaves
- **Rose Rust**: Rose leaves affected by rust disease
- **Rose Sawfly/Rose Slug**: Rose leaves affected by sawfly or slug damage
"""
st.markdown(class_overview)

# Working directories: 'models' is where the model/metrics files are looked
# up, 'temp' receives the uploaded image. Both are created if missing.
for required_dir in ('models', 'temp'):
    os.makedirs(required_dir, exist_ok=True)

# Upload widget; only the listed image extensions are accepted.
uploaded_file = st.file_uploader(
    "Upload a Rose Leaf Image",
    type=["jpg", "png", "jpeg"],
)

def _remove_quietly(path):
    """Delete the file at *path* on a best-effort basis.

    A failure to delete (file already gone, still locked, ...) is ignored
    on purpose: cleanup must never mask the result shown to the user.
    """
    try:
        os.unlink(path)
    except OSError:
        pass


def _show_probabilities(probabilities):
    """Render a bar chart from *probabilities*.

    The mapping's keys are used as the bar labels and its values as the bar
    heights. The figure is closed after rendering so figures do not pile up
    across Streamlit reruns.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.bar(list(probabilities.keys()), list(probabilities.values()))
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_ylabel('Probability')
        ax.set_title('Class Probabilities')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _show_metrics(metrics_path):
    """Show accuracy/precision/recall from the JSON file at *metrics_path*.

    Does nothing when the file does not exist. Missing keys are displayed
    as 0.
    """
    if not os.path.exists(metrics_path):
        return

    # Read first, render afterwards, so the file is not held open while
    # widgets are being built.
    with open(metrics_path, "r") as f:
        metrics = json.load(f)

    st.subheader("Model Performance Metrics")
    entries = (
        ("Accuracy", "accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
    )
    for column, (title, key) in zip(st.columns(3), entries):
        with column:
            st.metric(title, f"{metrics.get(key, 0):.2%}")


def _show_confusion_matrix(cm_path):
    """Draw the confusion matrix stored as JSON at *cm_path* as a heatmap.

    Does nothing when the file does not exist. Cells are annotated with
    the integer format 'd'.
    """
    if not os.path.exists(cm_path):
        return

    with open(cm_path, "r") as f:
        cm = np.array(json.load(f))

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_title('Confusion Matrix')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _report_prediction(model_path, image_path):
    """Run the prediction for *image_path* and render all result sections."""
    label, confidence = predict(model_path, image_path)

    # Customize the output based on the prediction
    if "Healthy" in label:
        st.success(f"**Prediction**: {label} ({confidence*100:.2f}% confidence)")
    else:
        st.warning(f"**Prediction**: {label} ({confidence*100:.2f}% confidence)")
        st.info("⚠️ This leaf appears to be affected by a disease. Please take appropriate measures.")

    # NOTE(review): predict() and get_class_probabilities() both receive the
    # model path; presumably each loads the model itself - confirm in utils
    # and consider sharing/caching the loaded model there.
    st.subheader("Probability Distribution")
    _show_probabilities(get_class_probabilities(model_path, image_path))

    _show_metrics("models/metrics.json")
    _show_confusion_matrix("models/confusion_matrix.json")


if uploaded_file is not None:
    # Streamlit re-executes this script on every interaction, so a new
    # temporary copy of the upload is written on each run. It is removed in
    # the outer ``finally`` on every run, not only after a prediction, so
    # the 'temp' directory does not fill up.
    temp_path = None
    try:
        # Save the uploaded file to a temporary location
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg', dir='temp') as tmp_file:
            # Record the path before writing so a failed write is cleaned up too.
            temp_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())

        # Open and display the image; the context manager releases the file
        # handle so the later delete is not blocked by an open handle.
        with Image.open(temp_path) as image:
            st.image(image, caption="Uploaded Image", use_container_width=True)

        if st.button("Detect Disease"):
            model_path = "models/rose_model.h5"
            if not os.path.exists(model_path):
                st.error("Model not found. Please train the model first.")
                # Ends this script run; the ``finally`` below still executes.
                st.stop()

            try:
                # Use the temporary file for prediction
                _report_prediction(model_path, temp_path)
            except Exception as e:
                st.error(f"Error during prediction: {str(e)}")
                st.info("Please make sure the model is trained and available in the models directory.")

    except Exception as e:
        st.error(f"Error processing the uploaded file: {str(e)}")
        st.info("Please try uploading the image again.")

    finally:
        # Clean up the temporary file
        if temp_path is not None:
            _remove_quietly(temp_path)