Regino commited on
Commit
4a29102
Β·
1 Parent(s): 1bd22cb

first commit

Browse files
Files changed (9) hide show
  1. .gitignore +2 -0
  2. app.py +180 -0
  3. con_mat.png +0 -0
  4. requirements.txt +6 -0
  5. rice_disease_cnn.pth +3 -0
  6. train.ipynb +0 -0
  7. train2.ipynb +0 -0
  8. train_loss.png +0 -0
  9. val_acc.png +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ rice_leaf_diseases/
2
+ rice_leaf_diseases.zip
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import os
8
+ import random
9
+ import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+ import zipfile # Added for zip extraction
12
+
13
+ # ---- Dataset Extraction Logic ----
14
+ DATASET_ZIP = "rice_leaf_diseases.zip"
15
+ DATASET_FOLDER = "rice_leaf_diseases"
16
+
17
+ # Check and extract dataset if needed
18
+ if not os.path.exists(DATASET_FOLDER):
19
+ if os.path.exists(DATASET_ZIP):
20
+ with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
21
+ zip_ref.extractall()
22
+ st.success("Dataset extracted successfully!")
23
+ else:
24
+ st.error(f"Dataset zip file '{DATASET_ZIP}' not found!")
25
+ else:
26
+ st.toast("Dataset already available!", icon="βœ…")
27
+
28
+ # Define Model Class
29
+ class RiceDiseaseCNN(nn.Module):
30
+ def __init__(self, num_classes):
31
+ super(RiceDiseaseCNN, self).__init__()
32
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
33
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
34
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
35
+ self.bn1 = nn.BatchNorm2d(32)
36
+ self.bn2 = nn.BatchNorm2d(64)
37
+ self.bn3 = nn.BatchNorm2d(128)
38
+ self.pool = nn.MaxPool2d(2, 2)
39
+ self.dropout = nn.Dropout(0.4)
40
+ self.fc1 = nn.Linear(128 * 16 * 16, 512)
41
+ self.fc2 = nn.Linear(512, num_classes)
42
+
43
+ def forward(self, x):
44
+ x = self.pool(F.relu(self.bn1(self.conv1(x))))
45
+ x = self.pool(F.relu(self.bn2(self.conv2(x))))
46
+ x = self.pool(F.relu(self.bn3(self.conv3(x))))
47
+ x = x.view(x.size(0), -1)
48
+ x = F.relu(self.fc1(x))
49
+ x = self.dropout(x)
50
+ x = self.fc2(x)
51
+ return x
52
+
53
+ # Load Model
54
+ device = torch.device("cpu")
55
+ num_classes = 3
56
+ model = RiceDiseaseCNN(num_classes)
57
+ model.load_state_dict(torch.load("rice_disease_cnn.pth", map_location=device))
58
+ model.eval()
59
+
60
+ # Define Transformations
61
+ transform = transforms.Compose([
62
+ transforms.Resize((128, 128)),
63
+ transforms.ToTensor(),
64
+ transforms.Normalize((0.5,), (0.5,))
65
+ ])
66
+
67
+ # Class Labels
68
+ class_labels = ["Bacterial Leaf Blight", "Brown Spot", "Leaf Smut"]
69
+
70
+ # Streamlit App Configuration
71
+ st.set_page_config(page_title="Rice Disease Detection", layout="wide")
72
+
73
+ # Define dataset path after extraction
74
+ dataset_path = DATASET_FOLDER
75
+
76
+ # Sidebar Navigation
77
+ st.sidebar.title("Navigation")
78
+ page = st.sidebar.radio("Go to", ["Dataset", "Data Visualization", "Model Metrics", "Classification"])
79
+
80
+ # Dataset Page
81
+ if page == "Dataset":
82
+ st.title("Rice Leaf Disease Dataset 🌾")
83
+ st.markdown("""
84
+ This dataset contains images of rice leaves affected by three common diseases:
85
+ - **Bacterial Leaf Blight**: Caused by *Xanthomonas oryzae* bacteria.
86
+ - **Brown Spot**: Caused by *Cochliobolus miyabeanus* fungus.
87
+ - **Leaf Smut**: Caused by *Entyloma oryzae* fungus.
88
+ The dataset is available on [Kaggle](https://www.kaggle.com/datasets/vbookshelf/rice-leaf-diseases).
89
+ """)
90
+
91
+ def get_sample_images(label, count=3):
92
+ label_path = os.path.join(dataset_path, label)
93
+ images = [img for img in os.listdir(label_path) if img.endswith(("png", "jpg", "jpeg"))]
94
+ sample_images = random.sample(images, min(count, len(images)))
95
+ return [os.path.join(label_path, img) for img in sample_images]
96
+
97
+ st.subheader("Sample Images from Dataset")
98
+ cols = st.columns(3)
99
+ for idx, label in enumerate(class_labels):
100
+ images = get_sample_images(label)
101
+ with cols[idx]:
102
+ st.write(f"### {label}")
103
+ for img_path in images:
104
+ st.image(img_path, use_column_width=True)
105
+
106
+ # Data Visualization Page
107
+ elif page == "Data Visualization":
108
+ st.title("Data Visualization πŸ“Š")
109
+
110
+ def get_image_count(label):
111
+ label_path = os.path.join(dataset_path, label)
112
+ return len([img for img in os.listdir(label_path) if img.endswith(("png", "jpg", "jpeg"))])
113
+
114
+ class_counts = {label: get_image_count(label) for label in class_labels}
115
+
116
+ st.subheader("Class Distribution")
117
+ df = pd.DataFrame(list(class_counts.items()), columns=["Disease", "Count"])
118
+
119
+ # Pie Chart
120
+ fig, ax = plt.subplots()
121
+ ax.pie(df["Count"], labels=df["Disease"], autopct='%1.1f%%', startangle=90)
122
+ ax.axis('equal')
123
+ st.pyplot(fig)
124
+
125
+ # Bar Chart
126
+ fig, ax = plt.subplots()
127
+ ax.bar(df["Disease"], df["Count"], color=['#1f77b4', '#ff7f0e', '#2ca02c'])
128
+ ax.set_xlabel('Disease Type')
129
+ ax.set_ylabel('Number of Images')
130
+ st.pyplot(fig)
131
+
132
+ # Model Metrics Page
133
+ elif page == "Model Metrics":
134
+ st.title("Model Performance Metrics πŸ“ˆ")
135
+ st.markdown("""
136
+ ### Model Architecture
137
+ - **Convolutional Layers** with Batch Normalization
138
+ - **MaxPooling** for dimension reduction
139
+ - **Fully Connected Layers** for classification
140
+ """)
141
+
142
+ # Confusion Matrix
143
+ st.subheader("Confusion Matrix")
144
+ st.image("con_mat.png", use_column_width=True)
145
+
146
+ # Training Curves
147
+ col1, col2 = st.columns(2)
148
+ with col1:
149
+ st.subheader("Training Loss")
150
+ st.image("train_loss.png")
151
+ with col2:
152
+ st.subheader("Validation Accuracy")
153
+ st.image("val_acc.png")
154
+
155
+ # Classification Report
156
+ st.subheader("Classification Report")
157
+ st.code("""
158
+ precision recall f1-score support
159
+ Bacterial Leaf Blight 0.90 1.00 0.95 9
160
+ Brown Spot 1.00 1.00 1.00 5
161
+ Leaf Smut 1.00 0.75 0.86 4
162
+ """)
163
+
164
+ # Classification Page
165
+ elif page == "Classification":
166
+ st.title("Rice Leaf Disease Classification πŸ”")
167
+
168
+ uploaded_file = st.file_uploader("Upload rice leaf image", type=["jpg", "png", "jpeg"])
169
+
170
+ if uploaded_file:
171
+ image = Image.open(uploaded_file).convert("RGB")
172
+ st.image(image, use_column_width=True)
173
+
174
+ # Transform and predict
175
+ image_tensor = transform(image).unsqueeze(0)
176
+ with torch.no_grad():
177
+ output = model(image_tensor)
178
+ _, predicted = torch.max(output, 1)
179
+
180
+ st.success(f"**Prediction:** {class_labels[predicted.item()]}")
con_mat.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ pandas
5
+ matplotlib
6
+ pillow
rice_disease_cnn.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8f7dc36dd4d5d7ba02214e06cef5d0bf960c26cda6d62e58d5dffe442a90df4
3
+ size 67501584
train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
train2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
train_loss.png ADDED
val_acc.png ADDED