"""Streamlit app for browsing the rice leaf disease dataset and classifying uploads."""
import os
import random
import zipfile

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# ✅ First Streamlit command - required by Streamlit
st.set_page_config(page_title="Rice Disease Detection", layout="wide")

# ================= ZIP FILE HANDLING =================
DATASET_PATH = "rice_leaf_diseases"
ZIP_FILE = "rice_leaf_diseases.zip"

# Extensions treated as images; matched case-insensitively so ".JPG" files count too.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Silent extraction without Streamlit messages. Assumes the archive contains a
# top-level DATASET_PATH folder (NOTE(review): confirm the zip layout).
if not os.path.exists(DATASET_PATH) and os.path.exists(ZIP_FILE):
    with zipfile.ZipFile(ZIP_FILE, "r") as zip_ref:
        zip_ref.extractall(".")  # Extract to current directory


def _list_class_dirs(root):
    """Return the sorted sub-directory names of *root*, or [] if *root* is missing.

    Only directories are returned, so stray files such as ``.DS_Store`` do not
    inflate the class count used to size the model.
    """
    if not os.path.isdir(root):
        return []
    return sorted(
        name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
    )


def _list_images(folder):
    """Return the image file names in *folder*, or [] if the folder does not exist."""
    if not os.path.isdir(folder):
        return []
    return [name for name in os.listdir(folder) if name.lower().endswith(IMAGE_EXTENSIONS)]


def _show_image(path, **kwargs):
    """Render the image at *path* with ``st.image``, or warn if the file is missing."""
    if os.path.exists(path):
        st.image(path, **kwargs)
    else:
        st.warning(f"Image not found: {path}")


# ✅ Load Class Names from Extracted Dataset (fallback when no class folders exist).
CLASS_NAMES = _list_class_dirs(DATASET_PATH) or [
    "Bacterial Leaf Blight",
    "Brown Spot",
    "Leaf Smut",
]

# ================= ORIGINAL APP CODE =================


class RiceDiseaseCNN(nn.Module):
    """Three conv/batch-norm/max-pool stages followed by two fully connected layers.

    ``fc1`` expects 128 * 16 * 16 features. That matches a 3-channel 128x128
    input halved by three 2x2 pools (128 -> 64 -> 32 -> 16), the size
    produced by ``transform`` below.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.4)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalized class scores (logits) for the batch *x*."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


@st.cache_resource
def load_model():
    """Build the CNN, load ``rice_disease_cnn.pth`` onto the CPU and set eval mode.

    Cached by Streamlit, so the weights are loaded once per server process
    rather than on every rerun.
    """
    device = torch.device("cpu")
    model = RiceDiseaseCNN(len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load("rice_disease_cnn.pth", map_location=device))
    model.eval()
    return model


model = load_model()

# Preprocessing; must match training. A one-element mean/std broadcasts over all 3 channels.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Class labels, indexed by the model's output position.
class_labels = ["Bacterial leaf blight", "Brown spot", "Leaf smut"]

# Define dataset path after extraction
dataset_path = DATASET_PATH

CLASSIFICATION_REPORT = """\
                       precision    recall  f1-score   support

Bacterial Leaf Blight       0.90      1.00      0.95         9
           Brown Spot       1.00      1.00      1.00         5
            Leaf Smut       1.00      0.75      0.86         4
"""

# Sidebar Navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio(
    "Go to", ["Dataset", "Data Visualization", "Model Metrics", "Classification"]
)

# Dataset Page
if page == "Dataset":
    st.title("Rice Leaf Disease Dataset 🌾")
    st.markdown("""
    This dataset contains images of rice leaves affected by three common diseases:
    - **Bacterial Leaf Blight**: Caused by *Xanthomonas oryzae* bacteria.
    - **Brown Spot**: Caused by *Cochliobolus miyabeanus* fungus.
    - **Leaf Smut**: Caused by *Entyloma oryzae* fungus.

    The dataset is available on [Kaggle](https://www.kaggle.com/datasets/vbookshelf/rice-leaf-diseases).
    """)

    def get_sample_images(label, count=3):
        """Return up to *count* random image paths from the *label* folder."""
        label_path = os.path.join(dataset_path, label)
        images = _list_images(label_path)
        sample_images = random.sample(images, min(count, len(images)))
        return [os.path.join(label_path, img) for img in sample_images]

    st.subheader("Sample Images from Dataset")
    if not os.path.isdir(dataset_path):
        st.warning(f"Dataset folder '{dataset_path}' not found.")
    else:
        cols = st.columns(len(class_labels))
        for col, label in zip(cols, class_labels):
            images = get_sample_images(label)
            with col:
                st.write(f"### {label}")
                if not images:
                    st.caption("No images found.")
                for img_path in images:
                    st.image(img_path, use_container_width=True)

# Data Visualization Page
elif page == "Data Visualization":
    st.title("Data Visualization 📊")

    def get_image_count(label):
        """Return the number of image files in the *label* folder (0 if missing)."""
        return len(_list_images(os.path.join(dataset_path, label)))

    class_counts = {label: get_image_count(label) for label in class_labels}

    st.subheader("Class Distribution")
    df = pd.DataFrame(list(class_counts.items()), columns=["Disease", "Count"])

    if df["Count"].sum() == 0:
        # A pie chart of all-zero wedges is undefined; tell the user instead.
        st.warning(f"No images found under '{dataset_path}'.")
    else:
        # Pie Chart
        fig, ax = plt.subplots()
        ax.pie(df["Count"], labels=df["Disease"], autopct="%1.1f%%", startangle=90)
        ax.axis("equal")
        st.pyplot(fig)
        plt.close(fig)  # release the figure; Streamlit reruns would otherwise accumulate them

        # Bar Chart
        fig, ax = plt.subplots()
        ax.bar(df["Disease"], df["Count"], color=["#1f77b4", "#ff7f0e", "#2ca02c"])
        ax.set_xlabel("Disease Type")
        ax.set_ylabel("Number of Images")
        st.pyplot(fig)
        plt.close(fig)

# Model Metrics Page
elif page == "Model Metrics":
    st.title("Model Performance Metrics 📈")
    st.markdown("""
    ### Model Architecture
    - **Convolutional Layers** with Batch Normalization
    - **MaxPooling** for dimension reduction
    - **Fully Connected Layers** for classification
    """)

    # Confusion Matrix
    st.subheader("Confusion Matrix")
    _show_image("con_mat.png", use_container_width=True)

    # Training Curves
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Training Loss")
        _show_image("train_loss.png")
    with col2:
        st.subheader("Validation Accuracy")
        _show_image("val_acc.png")

    # Classification Report
    st.subheader("Classification Report")
    st.code(CLASSIFICATION_REPORT)

# Classification Page
elif page == "Classification":
    st.title("Rice Leaf Disease Classification 🔍")
    uploaded_file = st.file_uploader("Upload rice leaf image", type=["jpg", "png", "jpeg"])

    if uploaded_file:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
            st.error("Could not read the uploaded file as an image.")
        else:
            st.image(image, use_container_width=True)

            # Transform and predict (add a batch dimension of 1).
            image_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                output = model(image_tensor)
            _, predicted = torch.max(output, 1)
            st.success(f"**Prediction:** {class_labels[predicted.item()]}")