Saketh12345 commited on
Commit
4c6bdbd
·
1 Parent(s): 174ae28

chore: remove unused app_resnet9.py

Browse files
Files changed (1) hide show
  1. app_resnet9.py +0 -240
app_resnet9.py DELETED
@@ -1,240 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import torchvision.transforms as transforms
5
- from PIL import Image
6
- import json
7
- import numpy as np
8
- from pathlib import Path
9
-
10
- # Page config
11
- st.set_page_config(
12
- page_title="Plant Disease Classifier",
13
- page_icon="🌱",
14
- layout="wide"
15
- )
16
-
17
- # Custom CSS
18
- st.markdown("""
19
- <style>
20
- .main {
21
- max-width: 1000px;
22
- padding: 2rem;
23
- }
24
- .title {
25
- text-align: center;
26
- color: #2e8b57;
27
- }
28
- .prediction {
29
- font-size: 1.2rem;
30
- padding: 1rem;
31
- border-radius: 0.5rem;
32
- margin-top: 1rem;
33
- }
34
- .healthy {
35
- background-color: #d4edda;
36
- color: #155724;
37
- }
38
- .diseased {
39
- background-color: #f8d7da;
40
- color: #721c24;
41
- }
42
- </style>
43
- """, unsafe_allow_html=True)
44
-
45
- # Model class (same as in training)
46
- class ConvBlock(nn.Module):
47
- def __init__(self, in_channels, out_channels, pool=False):
48
- super().__init__()
49
- layers = [
50
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
51
- nn.BatchNorm2d(out_channels),
52
- nn.ReLU(inplace=True)
53
- ]
54
- if pool:
55
- layers.append(nn.MaxPool2d(2))
56
- self.conv = nn.Sequential(*layers)
57
-
58
- def forward(self, x):
59
- return self.conv(x)
60
-
61
- class ResNet9(nn.Module):
62
- def __init__(self, in_channels, num_classes):
63
- super().__init__()
64
-
65
- # First conv block
66
- self.features = nn.Sequential(
67
- # Conv1
68
- nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False),
69
- nn.BatchNorm2d(64),
70
- nn.ReLU(inplace=True),
71
-
72
- # Conv2
73
- nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
74
- nn.BatchNorm2d(128),
75
- nn.ReLU(inplace=True),
76
- nn.MaxPool2d(kernel_size=2, stride=2),
77
-
78
- # Res1
79
- nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
80
- nn.BatchNorm2d(128),
81
- nn.ReLU(inplace=True),
82
- nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
83
- nn.BatchNorm2d(128),
84
- nn.ReLU(inplace=True),
85
-
86
- # Conv3
87
- nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
88
- nn.BatchNorm2d(256),
89
- nn.ReLU(inplace=True),
90
- nn.MaxPool2d(kernel_size=2, stride=2),
91
-
92
- # Conv4
93
- nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
94
- nn.BatchNorm2d(512),
95
- nn.ReLU(inplace=True),
96
- nn.MaxPool2d(kernel_size=2, stride=2),
97
-
98
- # Res2
99
- nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
100
- nn.BatchNorm2d(512),
101
- nn.ReLU(inplace=True),
102
- nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
103
- nn.BatchNorm2d(512),
104
- nn.ReLU(inplace=True)
105
- )
106
-
107
- self.classifier = nn.Sequential(
108
- nn.AdaptiveAvgPool2d(1),
109
- nn.Flatten(),
110
- nn.Dropout(0.2),
111
- nn.Linear(512, 256),
112
- nn.ReLU(inplace=True),
113
- nn.Linear(256, num_classes)
114
- )
115
-
116
- def forward(self, x):
117
- x = self.features(x)
118
- x = self.classifier(x)
119
- return x
120
-
121
- # Load class indices
122
- @st.cache_data
123
- def load_class_indices():
124
- with open('class_indices.json', 'r') as f:
125
- return json.load(f)
126
-
127
- # Load model
128
- @st.cache_resource
129
- def load_model():
130
- class_indices = load_class_indices()
131
- model = ResNet9(3, len(class_indices))
132
- model.load_state_dict(torch.load('plant_disease_model.pth', map_location=torch.device('cpu')))
133
- model.eval()
134
- return model
135
-
136
- # Preprocess image
137
- def preprocess_image(image):
138
- transform = transforms.Compose([
139
- transforms.Resize((256, 256)),
140
- transforms.ToTensor(),
141
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
142
- ])
143
- return transform(image).unsqueeze(0)
144
-
145
- # Predict function
146
- def predict(image, model, class_indices):
147
- idx_to_class = {int(k): v for k, v in class_indices.items()}
148
-
149
- # Preprocess
150
- input_tensor = preprocess_image(image)
151
-
152
- # Predict
153
- with torch.no_grad():
154
- output = model(input_tensor)
155
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
156
- confidence, predicted_idx = torch.max(probabilities, 0)
157
- predicted_class = idx_to_class[predicted_idx.item()]
158
-
159
- return predicted_class, confidence.item()
160
-
161
- # Main app
162
- def main():
163
- st.title("🌱 Plant Disease Classifier")
164
- st.markdown("---")
165
-
166
- # Load model and class indices
167
- try:
168
- model = load_model()
169
- class_indices = load_class_indices()
170
- idx_to_class = {int(k): v for k, v in class_indices.items()}
171
- except Exception as e:
172
- st.error(f"Error loading model: {str(e)}")
173
- st.info("Please make sure you have trained the model first by running 'python resnet9_train.py'")
174
- return
175
-
176
- # File uploader
177
- uploaded_file = st.file_uploader("Upload an image of a plant leaf", type=["jpg", "jpeg", "png"])
178
-
179
- if uploaded_file is not None:
180
- # Display image
181
- image = Image.open(uploaded_file).convert('RGB')
182
- st.image(image, caption='Uploaded Image', use_column_width=True)
183
-
184
- # Make prediction
185
- with st.spinner('Analyzing...'):
186
- predicted_class, confidence = predict(image, model, class_indices)
187
-
188
- # Display result
189
- plant, status = predicted_class.split('___')
190
- is_healthy = status == 'healthy'
191
-
192
- st.markdown("### Prediction Result")
193
- col1, col2 = st.columns(2)
194
-
195
- with col1:
196
- st.metric("Plant", plant.replace('_', ' ').title())
197
- with col2:
198
- status_display = "Healthy 🟢" if is_healthy else "Diseased 🔴"
199
- st.metric("Status", status_display)
200
-
201
- if not is_healthy:
202
- st.metric("Disease", status.replace('_', ' ').title())
203
-
204
- st.metric("Confidence", f"{confidence*100:.2f}%")
205
-
206
- # Show info based on prediction
207
- if is_healthy:
208
- st.success(f"This {plant.replace('_', ' ').lower()} leaf appears to be healthy!")
209
- else:
210
- st.warning(f"This {plant.replace('_', ' ').lower()} leaf shows signs of {status.replace('_', ' ').lower()}.")
211
-
212
- # Add some general advice (you can expand this)
213
- st.info("""
214
- **Recommendations:**
215
- - Isolate the affected plant to prevent spread
216
- - Remove severely infected leaves
217
- - Consider using appropriate fungicides/pesticides
218
- - Ensure proper spacing and air circulation
219
- - Maintain optimal watering practices
220
- """)
221
- else:
222
- st.info("Please upload an image of a plant leaf to check for diseases.")
223
-
224
- # Add some information about the model
225
- st.markdown("---")
226
- st.markdown("""
227
- ### About this App
228
- This app uses a ResNet9 deep learning model to identify plant diseases from leaf images.
229
- It can detect 38 different classes of plant diseases across 14 plant species.
230
-
231
- **How to use:**
232
- 1. Upload an image of a plant leaf
233
- 2. The model will analyze the image
234
- 3. View the prediction and recommendations
235
-
236
- **Note:** For best results, use clear, well-lit photos of individual leaves.
237
- """)
238
-
239
- if __name__ == "__main__":
240
- main()