File size: 6,426 Bytes
3b985d4
 
 
 
 
 
4a29102
3b985d4
 
 
b926f62
 
3b985d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ccfc90
3b985d4
 
 
ff48f40
3b985d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e95371
3b985d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e95371
3b985d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e95371
3b985d4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import zipfile

# ✅ Streamlit requires set_page_config to be the very first st.* call in the script.
st.set_page_config(layout="wide", page_title="Rice Disease Detection")

# ================= ZIP FILE HANDLING =================
DATASET_PATH = "rice_leaf_diseases"
ZIP_FILE = "rice_leaf_diseases.zip"
# Used when the dataset folder is absent or contains no class sub-folders.
FALLBACK_CLASS_NAMES = ["Bacterial Leaf Blight", "Brown Spot", "Leaf Smut"]


def list_class_dirs(path):
    """Return the sorted names of the visible sub-directories of *path*.

    Only directories count as classes. Stray files (e.g. ``.DS_Store`` or a
    README) and hidden entries are ignored, because ``len(CLASS_NAMES)``
    sizes the model's output layer and one extra entry would make the saved
    weights fail to load.

    Args:
        path: dataset root, expected to hold one sub-folder per class.

    Returns:
        Sorted list of sub-directory names, or ``[]`` if *path* is not a
        directory.
    """
    if not os.path.isdir(path):
        return []
    return sorted(
        name
        for name in os.listdir(path)
        if not name.startswith(".") and os.path.isdir(os.path.join(path, name))
    )


# Silent extraction without Streamlit messages
if not os.path.exists(DATASET_PATH) and os.path.exists(ZIP_FILE):
    with zipfile.ZipFile(ZIP_FILE, "r") as zip_ref:
        # Extract into the current directory; the archive presumably contains
        # the top-level DATASET_PATH folder (it is looked up there below).
        zip_ref.extractall(".")

# ✅ Load class names from the extracted dataset (one sub-folder per class),
# falling back to the known labels if nothing usable is found.
CLASS_NAMES = list_class_dirs(DATASET_PATH) or FALLBACK_CLASS_NAMES

# ================= ORIGINAL APP CODE =================
# Define Model Class
class RiceDiseaseCNN(nn.Module):
    """Three conv -> batch-norm -> ReLU -> max-pool stages followed by a two-layer MLP head.

    The attribute names (conv1..conv3, bn1..bn3, fc1, fc2) are the state-dict
    keys the checkpoint is loaded against, so they must stay as they are.

    Args:
        num_classes: number of output logits produced by ``fc2``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Sub-modules are registered in the same order as the original definition.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.4)
        # Three 2x poolings reduce a 128x128 input to 16x16 maps with 128 channels.
        self.fc1 = nn.Linear(in_features=128 * 16 * 16, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Expects a (batch, 3, 128, 128) input so that the flattened features
        match ``fc1``'s input size.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))
        hidden = F.relu(self.fc1(torch.flatten(x, start_dim=1)))
        return self.fc2(self.dropout(hidden))

# Load Model
@st.cache_resource
def load_model():
    """Build the classifier and load its trained weights from disk.

    Decorated with ``st.cache_resource`` so the checkpoint is read once and
    the same model object is reused across Streamlit reruns.

    Returns:
        A RiceDiseaseCNN with ``len(CLASS_NAMES)`` outputs, its weights
        mapped to CPU, switched to eval mode.
    """
    net = RiceDiseaseCNN(len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on PyTorch versions that support it).
    weights = torch.load("rice_disease_cnn.pth", map_location=torch.device("cpu"))
    net.load_state_dict(weights)
    # Module.eval() returns the module itself, so it can be returned directly.
    return net.eval()


model = load_model()

# Define Transformations
# Preprocessing applied to each uploaded image before inference. The resize
# to 128x128 matches the input size fc1 expects after three 2x poolings;
# the single-value mean/std tuples are applied to every channel.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Class Labels
# Display names looked up by the predicted class index. They are also used as
# dataset sub-folder names on the Dataset and Data Visualization pages.
# NOTE(review): order presumably matches the label order used in training.
class_labels = [
    "Bacterial leaf blight",
    "Brown spot",
    "Leaf smut",
]


# Define dataset path after extraction
dataset_path = DATASET_PATH

# Image file extensions shown/counted on the dataset pages. They are matched
# case-insensitively, so files such as "DSC_0001.JPG" are included too.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def list_label_images(label):
    """Return the image file names stored in the dataset folder for *label*.

    Args:
        label: class name, used as a sub-folder of ``dataset_path``.

    Returns:
        File names (not full paths) whose lower-cased name ends with one of
        IMAGE_EXTENSIONS, or ``[]`` if the folder does not exist.
    """
    label_path = os.path.join(dataset_path, label)
    if not os.path.isdir(label_path):
        return []
    return [
        name
        for name in os.listdir(label_path)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]


def show_figure(fig):
    """Render a Matplotlib figure in Streamlit, then close it.

    Streamlit re-executes this script on every interaction. Without
    ``plt.close``, pyplot keeps a reference to every figure ever created.
    """
    st.pyplot(fig)
    plt.close(fig)


# Sidebar Navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Dataset", "Data Visualization", "Model Metrics", "Classification"])

# Dataset Page
if page == "Dataset":
    st.title("Rice Leaf Disease Dataset 🌾")
    st.markdown("""
        This dataset contains images of rice leaves affected by three common diseases:
        - **Bacterial Leaf Blight**: Caused by *Xanthomonas oryzae* bacteria.
        - **Brown Spot**: Caused by *Cochliobolus miyabeanus* fungus.
        - **Leaf Smut**: Caused by *Entyloma oryzae* fungus.
        The dataset is available on [Kaggle](https://www.kaggle.com/datasets/vbookshelf/rice-leaf-diseases).
    """)

    def get_sample_images(label, count=3):
        """Return up to *count* randomly chosen image paths for *label*."""
        label_path = os.path.join(dataset_path, label)
        images = list_label_images(label)
        sample_images = random.sample(images, min(count, len(images)))
        return [os.path.join(label_path, img) for img in sample_images]

    st.subheader("Sample Images from Dataset")
    if not os.path.isdir(dataset_path):
        # Previously os.listdir raised FileNotFoundError and crashed the page.
        st.warning(f"Dataset folder '{dataset_path}' was not found, so no sample images can be shown.")
    else:
        cols = st.columns(len(class_labels))
        for col, label in zip(cols, class_labels):
            images = get_sample_images(label)
            with col:
                st.write(f"### {label}")
                if not images:
                    st.caption("No images found for this class.")
                for img_path in images:
                    st.image(img_path, use_container_width=True)

# Data Visualization Page
elif page == "Data Visualization":
    st.title("Data Visualization 📊")

    class_counts = {label: len(list_label_images(label)) for label in class_labels}

    st.subheader("Class Distribution")
    df = pd.DataFrame(list(class_counts.items()), columns=["Disease", "Count"])

    if df["Count"].sum() == 0:
        # Also covers a missing dataset folder. An all-zero pie chart has
        # nothing to show, so skip plotting.
        st.warning(f"No images found under '{dataset_path}'; nothing to plot.")
    else:
        # Pie Chart
        fig, ax = plt.subplots()
        ax.pie(df["Count"], labels=df["Disease"], autopct='%1.1f%%', startangle=90)
        ax.axis('equal')
        show_figure(fig)

        # Bar Chart
        fig, ax = plt.subplots()
        ax.bar(df["Disease"], df["Count"], color=['#1f77b4', '#ff7f0e', '#2ca02c'])
        ax.set_xlabel('Disease Type')
        ax.set_ylabel('Number of Images')
        show_figure(fig)

# Model Metrics Page
elif page == "Model Metrics":
    st.title("Model Performance Metrics 📈")
    st.markdown("""
        ### Model Architecture
        - **Convolutional Layers** with Batch Normalization
        - **MaxPooling** for dimension reduction
        - **Fully Connected Layers** for classification
    """)

    # Confusion Matrix
    st.subheader("Confusion Matrix")
    st.image("con_mat.png", use_container_width=True)

    # Training Curves
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Training Loss")
        st.image("train_loss.png")
    with col2:
        st.subheader("Validation Accuracy")
        st.image("val_acc.png")

    # Classification Report
    st.subheader("Classification Report")
    st.code(""" 
        precision    recall  f1-score   support
Bacterial Leaf Blight       0.90      1.00      0.95         9
        Brown Spot       1.00      1.00      1.00         5
         Leaf Smut       1.00      0.75      0.86         4
    """)

# Classification Page
elif page == "Classification":
    st.title("Rice Leaf Disease Classification 🔍")

    uploaded_file = st.file_uploader("Upload rice leaf image", type=["jpg", "png", "jpeg"])

    if uploaded_file:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # corrupt or non-image uploads; report instead of crashing.
            st.error("Could not read the uploaded file as an image.")
        else:
            st.image(image, use_container_width=True)

            # Transform and predict (batch of one)
            image_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                output = model(image_tensor)
                _, predicted = torch.max(output, 1)

            st.success(f"**Prediction:** {class_labels[predicted.item()]}")