Regino commited on
Commit
138a538
Β·
1 Parent(s): a3f4e86

new comit

Browse files
Files changed (5) hide show
  1. app.py +117 -0
  2. class_names.txt +16 -0
  3. model.py +40 -0
  4. plant_disease_model.pth +3 -0
  5. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from PIL import Image
10
+ from torchvision import datasets
11
+ from torch.utils.data import DataLoader
12
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
13
+ import random
14
+
15
+ # βœ… Sidebar Navigation
16
+ st.sidebar.title("Navigation")
17
+ page = st.sidebar.radio("Go to", ["Dataset", "Visualizations", "Model Metrics", "Disease Predictor"])
18
+
19
+ # βœ… Dataset Path
20
+ DATASET_PATH = "dataset/train" # Update if needed
21
+ CLASS_NAMES = os.listdir(DATASET_PATH) # Get class names from folder structure
22
+
23
+ # βœ… Load Model
24
+ @st.cache_resource
25
+ def load_model():
26
+ model = models.mobilenet_v2(pretrained=False)
27
+ model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
28
+ model.load_state_dict(torch.load("plant_disease_model.pth", map_location=torch.device("cpu")))
29
+ model.eval()
30
+ return model
31
+
32
+ model = load_model()
33
+
34
+ # βœ… Dataset Page – Show Sample Images
35
+ if page == "Dataset":
36
+ st.title("πŸ“Š Dataset Preview")
37
+ st.write(f"### Classes: {CLASS_NAMES}")
38
+
39
+ # Show sample images from each class
40
+ cols = st.columns(4)
41
+ for i, class_name in enumerate(CLASS_NAMES[:4]): # Show 4 classes
42
+ class_path = os.path.join(DATASET_PATH, class_name)
43
+ image_name = random.choice(os.listdir(class_path))
44
+ image_path = os.path.join(class_path, image_name)
45
+ image = Image.open(image_path)
46
+ cols[i].image(image, caption=class_name, use_column_width=True)
47
+
48
+ # βœ… Visualizations Page – Show Class Distribution
49
+ elif page == "Visualizations":
50
+ st.title("πŸ“ˆ Dataset Visualizations")
51
+
52
+ # Count images per class
53
+ class_counts = {cls: len(os.listdir(os.path.join(DATASET_PATH, cls))) for cls in CLASS_NAMES}
54
+
55
+ # Pie Chart
56
+ st.write("### Disease Distribution")
57
+ fig, ax = plt.subplots()
58
+ ax.pie(class_counts.values(), labels=class_counts.keys(), autopct='%1.1f%%', colors=plt.cm.viridis.colors)
59
+ st.pyplot(fig)
60
+
61
+ # Bar Chart
62
+ st.write("### Class Count")
63
+ fig, ax = plt.subplots()
64
+ sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()), palette="viridis")
65
+ plt.xticks(rotation=45)
66
+ st.pyplot(fig)
67
+
68
+ # βœ… Model Metrics Page
69
+ elif page == "Model Metrics":
70
+ st.title("πŸ“Š Model Performance")
71
+
72
+ # Load True Labels and Predictions
73
+ y_true = torch.load("y_true.pth")
74
+ y_pred = torch.load("y_pred.pth")
75
+
76
+ # Accuracy
77
+ accuracy = accuracy_score(y_true, y_pred)
78
+ st.write(f"### βœ… Accuracy: {accuracy:.2f}")
79
+
80
+ # Classification Report
81
+ st.write("### πŸ“‹ Classification Report")
82
+ report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, output_dict=True)
83
+ st.write(pd.DataFrame(report).T)
84
+
85
+ # Confusion Matrix
86
+ st.write("### πŸ”€ Confusion Matrix")
87
+ cm = confusion_matrix(y_true, y_pred)
88
+ fig, ax = plt.subplots()
89
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
90
+ st.pyplot(fig)
91
+
92
+ # βœ… Disease Predictor Page
93
+ elif page == "Disease Predictor":
94
+ st.title("🌿 Plant Disease Classifier")
95
+
96
+ # File Upload
97
+ uploaded_file = st.file_uploader("Upload a plant leaf image", type=["jpg", "png", "jpeg"])
98
+
99
+ if uploaded_file is not None:
100
+ image = Image.open(uploaded_file)
101
+ st.image(image, caption="Uploaded Image", use_column_width=True)
102
+
103
+ # Transform Image
104
+ transform = transforms.Compose([
105
+ transforms.Resize((128, 128)),
106
+ transforms.ToTensor(),
107
+ transforms.Normalize([0.5], [0.5])
108
+ ])
109
+
110
+ image_tensor = transform(image).unsqueeze(0)
111
+
112
+ # Predict Disease
113
+ with torch.no_grad():
114
+ output = model(image_tensor)
115
+ predicted_class = torch.argmax(output, dim=1).item()
116
+
117
+ st.write(f"### βœ… Prediction: {CLASS_NAMES[predicted_class]}")
class_names.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Pepper__bell___Bacterial_spot
2
+ Pepper__bell___healthy
3
+ PlantVillage
4
+ Potato___Early_blight
5
+ Potato___Late_blight
6
+ Potato___healthy
7
+ Tomato_Bacterial_spot
8
+ Tomato_Early_blight
9
+ Tomato_Late_blight
10
+ Tomato_Leaf_Mold
11
+ Tomato_Septoria_leaf_spot
12
+ Tomato_Spider_mites_Two_spotted_spider_mite
13
+ Tomato__Target_Spot
14
+ Tomato__Tomato_YellowLeaf__Curl_Virus
15
+ Tomato__Tomato_mosaic_virus
16
+ Tomato_healthy
model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from torchvision import models
4
+ from PIL import Image
5
+
6
+ # βœ… Load Class Names
7
+ with open("class_names.txt", "r") as f:
8
+ class_names = [line.strip() for line in f.readlines()]
9
+
10
+ # βœ… Load Trained Model
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model = models.mobilenet_v2(pretrained=False)
13
+ model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(class_names))
14
+ model.load_state_dict(torch.load("plant_disease_model.pth", map_location=device))
15
+ model = model.to(device)
16
+ model.eval()
17
+
18
+ # βœ… Image Transformations (Must match training settings)
19
+ transform = transforms.Compose([
20
+ transforms.Resize((128, 128)),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize([0.5], [0.5])
23
+ ])
24
+
25
+ # βœ… Function to Make Predictions
26
+ def predict_image(image_path):
27
+ image = Image.open(image_path).convert("RGB")
28
+ image = transform(image).unsqueeze(0).to(device)
29
+
30
+ with torch.no_grad():
31
+ output = model(image)
32
+ predicted_class = torch.argmax(output, dim=1).item()
33
+
34
+ return class_names[predicted_class]
35
+
36
+ # βœ… Test the model (optional)
37
+ if __name__ == "__main__":
38
+ sample_image = "test_image.jpg" # Replace with an actual image path
39
+ prediction = predict_image(sample_image)
40
+ print(f"Predicted Class: {prediction}")
plant_disease_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:892a520120fdb7827ff6a507639de765df7a1de817b9b56b140349977c802ed5
3
+ size 9221158
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ tqdm
5
+ streamlit