File size: 6,426 Bytes
3b985d4 4a29102 3b985d4 b926f62 3b985d4 3ccfc90 3b985d4 ff48f40 3b985d4 8e95371 3b985d4 8e95371 3b985d4 8e95371 3b985d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import zipfile
# Streamlit requires set_page_config to be the first Streamlit command the
# script executes, so it is issued before any other st.* call.
st.set_page_config(
    layout="wide",
    page_title="Rice Disease Detection",
)
# ================= ZIP FILE HANDLING =================
DATASET_PATH = "rice_leaf_diseases"
ZIP_FILE = "rice_leaf_diseases.zip"

# Used when the dataset folder is unavailable. Spelled exactly like the class
# folder names the pages read (``class_labels``) and kept in sorted order, so
# positions line up with a sorted directory listing.
FALLBACK_CLASS_NAMES = ["Bacterial leaf blight", "Brown spot", "Leaf smut"]


def extract_dataset(zip_path, dataset_dir, extract_to="."):
    """Unpack ``zip_path`` into ``extract_to`` if ``dataset_dir`` does not exist yet.

    Does nothing when the dataset is already present or the archive is
    missing, and deliberately emits no Streamlit messages.
    NOTE(review): assumes the archive contains ``dataset_dir`` as its top-level
    folder when extracted into ``extract_to`` -- confirm against the shipped zip.

    Args:
        zip_path: path of the zip archive.
        dataset_dir: folder whose presence means "already extracted".
        extract_to: directory the archive is extracted into.
    """
    if os.path.exists(dataset_dir) or not os.path.exists(zip_path):
        return
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(extract_to)


def discover_class_names(dataset_dir, fallback):
    """Return the sorted class-folder names found in ``dataset_dir``.

    Only sub-directories count (the same rule torchvision's ``ImageFolder``
    uses). Stray files such as ``.DS_Store`` would otherwise change
    ``len(CLASS_NAMES)`` and with it the size of the model's output layer.

    Args:
        dataset_dir: dataset root holding one sub-folder per class.
        fallback: names to use when no class folders can be found.

    Returns:
        A new list: the sorted sub-directory names, or a copy of ``fallback``
        when ``dataset_dir`` is missing or has no sub-directories.
    """
    if not os.path.isdir(dataset_dir):
        return list(fallback)
    with os.scandir(dataset_dir) as entries:
        names = sorted(entry.name for entry in entries if entry.is_dir())
    return names or list(fallback)


# Silent extraction without Streamlit messages.
extract_dataset(ZIP_FILE, DATASET_PATH)
CLASS_NAMES = discover_class_names(DATASET_PATH, FALLBACK_CLASS_NAMES)
# ================= ORIGINAL APP CODE =================
# Define Model Class
class RiceDiseaseCNN(nn.Module):
    """Small CNN classifier for rice-leaf disease images.

    Architecture: three stages of ``conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool``
    (32, 64 and 128 channels). These are followed by a 512-unit fully connected
    layer with dropout (p=0.4) and a final linear layer that produces
    ``num_classes`` logits.

    ``fc1`` is sized for a 128 x 128 spatial input. The padded 3x3 convolutions
    keep the side length, and each of the three poolings halves it
    (128 -> 64 -> 32 -> 16), leaving 128 channels of 16 x 16 features.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # Attribute names are the state_dict keys of the saved checkpoint, and
        # the registration order fixes how default initialisation consumes the
        # RNG -- both are kept exactly as before.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(in_features=128 * 16 * 16, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of 3-channel images."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))
        features = x.view(x.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
# Load Model
@st.cache_resource
def load_model():
    """Build the classifier, load its trained weights, and cache it across reruns.

    The network gets ``len(CLASS_NAMES)`` outputs. Its weights are read from
    ``rice_disease_cnn.pth`` and mapped onto the CPU.

    Returns:
        The ``RiceDiseaseCNN`` in eval mode (dropout off, batch-norm uses its
        running statistics).

    Raises:
        FileNotFoundError: if ``rice_disease_cnn.pth`` is missing.
    """
    device = torch.device("cpu")
    model = RiceDiseaseCNN(len(CLASS_NAMES))
    # The checkpoint is a plain state_dict, so restrict unpickling to tensors and
    # primitive containers; a tampered .pth file cannot run arbitrary code.
    state_dict = torch.load("rice_disease_cnn.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Preprocessing applied to an uploaded image before inference. It resizes to
# the 128x128 input the network is built for and converts to a CHW tensor. It
# then normalises every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Display names indexed by the model's output position. They are also the
# per-class sub-folder names read from the dataset directory.
class_labels = ["Bacterial leaf blight", "Brown spot", "Leaf smut"]

# Root of the extracted dataset, browsed by the Dataset and Data Visualization pages.
dataset_path = DATASET_PATH
# ================= PAGES =================
# Dataset image extensions. They are compared case-insensitively and include the
# dot, so "DSC_0001.JPG" is counted while a name merely ending in "png" is not.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# The blank line before the Kaggle sentence ends the bullet list; without it,
# Markdown folds that sentence into the "Leaf Smut" bullet.
_DATASET_MARKDOWN = """
This dataset contains images of rice leaves affected by three common diseases:
- **Bacterial Leaf Blight**: Caused by *Xanthomonas oryzae* bacteria.
- **Brown Spot**: Caused by *Cochliobolus miyabeanus* fungus.
- **Leaf Smut**: Caused by *Entyloma oryzae* fungus.

The dataset is available on [Kaggle](https://www.kaggle.com/datasets/vbookshelf/rice-leaf-diseases).
"""

_ARCHITECTURE_MARKDOWN = """
### Model Architecture
- **Convolutional Layers** with Batch Normalization
- **MaxPooling** for dimension reduction
- **Fully Connected Layers** for classification
"""

# Hard-coded report text; it is not recomputed at runtime.
_CLASSIFICATION_REPORT = """
precision recall f1-score support
Bacterial Leaf Blight 0.90 1.00 0.95 9
Brown Spot 1.00 1.00 1.00 5
Leaf Smut 1.00 0.75 0.86 4
"""


def _list_images(label):
    """Return the image file names inside the dataset folder for ``label``.

    Raises:
        FileNotFoundError: if ``dataset_path/label`` does not exist.
    """
    label_path = os.path.join(dataset_path, label)
    return [name for name in os.listdir(label_path) if name.lower().endswith(_IMAGE_EXTENSIONS)]


def get_sample_images(label, count=3):
    """Return paths of up to ``count`` randomly chosen images of class ``label``."""
    label_path = os.path.join(dataset_path, label)
    images = _list_images(label)
    chosen = random.sample(images, min(count, len(images)))
    return [os.path.join(label_path, name) for name in chosen]


def get_image_count(label):
    """Return how many images the dataset folder for ``label`` contains."""
    return len(_list_images(label))


def _dataset_available():
    """Return True if every class folder exists; otherwise show a warning and return False.

    The dataset-browsing pages call this first, so a missing dataset produces a
    readable message instead of a FileNotFoundError traceback.
    """
    missing = [label for label in class_labels if not os.path.isdir(os.path.join(dataset_path, label))]
    if missing:
        st.warning(f"Dataset folders not found under '{dataset_path}': {', '.join(missing)}")
        return False
    return True


def _show_figure(fig):
    """Render a matplotlib figure in Streamlit, then close it.

    pyplot keeps every figure created by ``plt.subplots`` alive until it is
    closed. Streamlit re-executes this script on each interaction, so unclosed
    figures would accumulate across reruns.
    """
    st.pyplot(fig)
    plt.close(fig)


def _render_dataset_page():
    """Describe the dataset and show random sample images for each class."""
    st.title("Rice Leaf Disease Dataset πΎ")
    st.markdown(_DATASET_MARKDOWN)
    if not _dataset_available():
        return
    st.subheader("Sample Images from Dataset")
    # One column per class instead of a hard-coded 3, so the grid follows class_labels.
    for column, label in zip(st.columns(len(class_labels)), class_labels):
        images = get_sample_images(label)
        with column:
            st.write(f"### {label}")
            for img_path in images:
                st.image(img_path, use_container_width=True)


def _render_visualization_page():
    """Show pie and bar charts of the number of images per class."""
    st.title("Data Visualization π")
    if not _dataset_available():
        return
    class_counts = {label: get_image_count(label) for label in class_labels}
    st.subheader("Class Distribution")
    df = pd.DataFrame(list(class_counts.items()), columns=["Disease", "Count"])
    # A pie chart of all-zero counts has no meaningful wedges.
    if df["Count"].sum() == 0:
        st.warning("No images were found in the dataset folders.")
        return

    # Pie Chart
    fig, ax = plt.subplots()
    ax.pie(df["Count"], labels=df["Disease"], autopct='%1.1f%%', startangle=90)
    ax.axis('equal')
    _show_figure(fig)

    # Bar Chart
    fig, ax = plt.subplots()
    ax.bar(df["Disease"], df["Count"], color=['#1f77b4', '#ff7f0e', '#2ca02c'])
    ax.set_xlabel('Disease Type')
    ax.set_ylabel('Number of Images')
    _show_figure(fig)


def _render_metrics_page():
    """Show the architecture summary, saved training plots and the classification report."""
    st.title("Model Performance Metrics π")
    st.markdown(_ARCHITECTURE_MARKDOWN)

    # Confusion Matrix
    st.subheader("Confusion Matrix")
    st.image("con_mat.png", use_container_width=True)

    # Training Curves
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Training Loss")
        st.image("train_loss.png")
    with col2:
        st.subheader("Validation Accuracy")
        st.image("val_acc.png")

    # Classification Report
    st.subheader("Classification Report")
    st.code(_CLASSIFICATION_REPORT)


def _render_classification_page():
    """Classify an uploaded leaf image with the loaded model and show the predicted label."""
    st.title("Rice Leaf Disease Classification π")
    uploaded_file = st.file_uploader("Upload rice leaf image", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return
    # convert() returns an independent image, so the opened source can be released at once.
    with Image.open(uploaded_file) as source:
        image = source.convert("RGB")
    st.image(image, use_container_width=True)

    # Transform and predict; unsqueeze adds the batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(image_tensor)
    predicted = output.argmax(dim=1)
    st.success(f"**Prediction:** {class_labels[predicted.item()]}")


# Sidebar label -> page renderer; insertion order is the order of the radio options.
_PAGES = {
    "Dataset": _render_dataset_page,
    "Data Visualization": _render_visualization_page,
    "Model Metrics": _render_metrics_page,
    "Classification": _render_classification_page,
}

# Sidebar Navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", list(_PAGES))
_PAGES[page]()
|