fatimaxa committed on
Commit
b68ebd5
·
verified ·
1 Parent(s): 12449dd

Upload 40 files

Browse files
Files changed (40) hide show
  1. .gitignore +249 -0
  2. __pycache__/data_prep.cpython-311.pyc +0 -0
  3. app.py +41 -0
  4. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-4615ab977727fc47.arrow +3 -0
  5. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-64f7f66e875a2297.arrow +3 -0
  6. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-a711edc7192ef3fb.arrow +3 -0
  7. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-f446c767d80f0d9a.arrow +3 -0
  8. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/dataset_info.json +1 -0
  9. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00000-of-00002.arrow +3 -0
  10. data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00001-of-00002.arrow +3 -0
  11. models/__pycache__/model.cpython-311.pyc +0 -0
  12. models/__pycache__/model.cpython-313.pyc +0 -0
  13. models/model.py +69 -0
  14. requirements.txt +0 -0
  15. saved_models/plant_cnn.pt +3 -0
  16. tabs/__pycache__/batch_processing.cpython-311.pyc +0 -0
  17. tabs/__pycache__/single_prediction.cpython-311.pyc +0 -0
  18. tabs/batch_processing.py +46 -0
  19. tabs/single_prediction.py +77 -0
  20. ui_text/about.md +25 -0
  21. ui_text/class_names.json +41 -0
  22. ui_text/disease_info.json +9 -0
  23. ui_text/examples/Apple___Apple_scab.jpg +0 -0
  24. ui_text/examples/Soybean___healthy.jpg +0 -0
  25. ui_text/examples/Tomato___Bacterial_spot.jpg +0 -0
  26. ui_text/examples/Tomato___Septoria_leaf_spot.jpg +0 -0
  27. ui_text/examples/Tomato___Tomato_Yellow_Leaf_Curl_Virus.jpg +0 -0
  28. ui_text/examples/Tomato___healthy.jpg +0 -0
  29. ui_text/intro.md +5 -0
  30. utils/__pycache__/chart_vis.cpython-311.pyc +0 -0
  31. utils/__pycache__/config.cpython-311.pyc +0 -0
  32. utils/__pycache__/config.cpython-313.pyc +0 -0
  33. utils/__pycache__/model_loader.cpython-311.pyc +0 -0
  34. utils/__pycache__/predictions.cpython-311.pyc +0 -0
  35. utils/__pycache__/vis.cpython-311.pyc +0 -0
  36. utils/chart_vis.py +31 -0
  37. utils/config.py +6 -0
  38. utils/model_loader.py +51 -0
  39. utils/predictions.py +58 -0
  40. utils/vis.py +100 -0
.gitignore ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .gap_gpu_env/
2
+ saved_models/
3
+ __pycache__/
4
+ *.py[cod]
5
+
6
+
7
+ # Created by https://www.toptal.com/developers/gitignore/api/venv,macos,python,visualstudiocode
8
+ # Edit at https://www.toptal.com/developers/gitignore?templates=venv,macos,python,visualstudiocode
9
+
10
+ ### macOS ###
11
+ # General
12
+ .DS_Store
13
+ .AppleDouble
14
+ .LSOverride
15
+
16
+ # Icon must end with two \r
17
+ Icon
18
+
19
+
20
+ # Thumbnails
21
+ ._*
22
+
23
+ # Files that might appear in the root of a volume
24
+ .DocumentRevisions-V100
25
+ .fseventsd
26
+ .Spotlight-V100
27
+ .TemporaryItems
28
+ .Trashes
29
+ .VolumeIcon.icns
30
+ .com.apple.timemachine.donotpresent
31
+
32
+ # Directories potentially created on remote AFP share
33
+ .AppleDB
34
+ .AppleDesktop
35
+ Network Trash Folder
36
+ Temporary Items
37
+ .apdisk
38
+
39
+ ### macOS Patch ###
40
+ # iCloud generated files
41
+ *.icloud
42
+
43
+ ### Python ###
44
+ # Byte-compiled / optimized / DLL files
45
+ __pycache__/
46
+ *.py[cod]
47
+ *$py.class
48
+
49
+ # C extensions
50
+ *.so
51
+
52
+ # Distribution / packaging
53
+ .Python
54
+ build/
55
+ develop-eggs/
56
+ dist/
57
+ downloads/
58
+ eggs/
59
+ .eggs/
60
+ lib/
61
+ lib64/
62
+ parts/
63
+ sdist/
64
+ var/
65
+ wheels/
66
+ share/python-wheels/
67
+ *.egg-info/
68
+ .installed.cfg
69
+ *.egg
70
+ MANIFEST
71
+
72
+ # PyInstaller
73
+ # Usually these files are written by a python script from a template
74
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
75
+ *.manifest
76
+ *.spec
77
+
78
+ # Installer logs
79
+ pip-log.txt
80
+ pip-delete-this-directory.txt
81
+
82
+ # Unit test / coverage reports
83
+ htmlcov/
84
+ .tox/
85
+ .nox/
86
+ .coverage
87
+ .coverage.*
88
+ .cache
89
+ nosetests.xml
90
+ coverage.xml
91
+ *.cover
92
+ *.py,cover
93
+ .hypothesis/
94
+ .pytest_cache/
95
+ cover/
96
+
97
+ # Translations
98
+ *.mo
99
+ *.pot
100
+
101
+ # Django stuff:
102
+ *.log
103
+ local_settings.py
104
+ db.sqlite3
105
+ db.sqlite3-journal
106
+
107
+ # Flask stuff:
108
+ instance/
109
+ .webassets-cache
110
+
111
+ # Scrapy stuff:
112
+ .scrapy
113
+
114
+ # Sphinx documentation
115
+ docs/_build/
116
+
117
+ # PyBuilder
118
+ .pybuilder/
119
+ target/
120
+
121
+ # Jupyter Notebook
122
+ .ipynb_checkpoints
123
+
124
+ # IPython
125
+ profile_default/
126
+ ipython_config.py
127
+
128
+ # pyenv
129
+ # For a library or package, you might want to ignore these files since the code is
130
+ # intended to run in multiple environments; otherwise, check them in:
131
+ # .python-version
132
+
133
+ # pipenv
134
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
135
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
136
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
137
+ # install all needed dependencies.
138
+ #Pipfile.lock
139
+
140
+ # poetry
141
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
142
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
143
+ # commonly ignored for libraries.
144
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
145
+ #poetry.lock
146
+
147
+ # pdm
148
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
149
+ #pdm.lock
150
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
151
+ # in version control.
152
+ # https://pdm.fming.dev/#use-with-ide
153
+ .pdm.toml
154
+
155
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
156
+ __pypackages__/
157
+
158
+ # Celery stuff
159
+ celerybeat-schedule
160
+ celerybeat.pid
161
+
162
+ # SageMath parsed files
163
+ *.sage.py
164
+
165
+ # Environments
166
+ .env
167
+ .venv
168
+ env/
169
+ venv/
170
+ ENV/
171
+ env.bak/
172
+ venv.bak/
173
+
174
+ # Spyder project settings
175
+ .spyderproject
176
+ .spyproject
177
+
178
+ # Rope project settings
179
+ .ropeproject
180
+
181
+ # mkdocs documentation
182
+ /site
183
+
184
+ # mypy
185
+ .mypy_cache/
186
+ .dmypy.json
187
+ dmypy.json
188
+
189
+ # Pyre type checker
190
+ .pyre/
191
+
192
+ # pytype static type analyzer
193
+ .pytype/
194
+
195
+ # Cython debug symbols
196
+ cython_debug/
197
+
198
+ # PyCharm
199
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
200
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
201
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
202
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
203
+ #.idea/
204
+
205
+ ### Python Patch ###
206
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
207
+ poetry.toml
208
+
209
+ # ruff
210
+ .ruff_cache/
211
+
212
+ # LSP config files
213
+ pyrightconfig.json
214
+
215
+ ### venv ###
216
+ # Virtualenv
217
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
218
+ [Bb]in
219
+ [Ii]nclude
220
+ [Ll]ib
221
+ [Ll]ib64
222
+ [Ll]ocal
223
+ [Ss]cripts
224
+ pyvenv.cfg
225
+ pip-selfcheck.json
226
+
227
+ ### VisualStudioCode ###
228
+ .vscode/*
229
+ !.vscode/settings.json
230
+ !.vscode/tasks.json
231
+ !.vscode/launch.json
232
+ !.vscode/extensions.json
233
+ !.vscode/*.code-snippets
234
+
235
+ # Local History for Visual Studio Code
236
+ .history/
237
+
238
+ # Built Visual Studio Code Extensions
239
+ *.vsix
240
+
241
+ ### VisualStudioCode Patch ###
242
+ # Ignore all local history of files
243
+ .history
244
+ .ionide
245
+
246
+ # End of https://www.toptal.com/developers/gitignore/api/venv,macos,python,visualstudiocode
247
+
248
+ data/
249
+ *.arrow
__pycache__/data_prep.cpython-311.pyc ADDED
Binary file (4.62 kB). View file
 
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from utils.model_loader import load_model_and_config, load_ui_text
from tabs.single_prediction import create_single_prediction_tab
from tabs.batch_processing import create_batch_processing_tab

# Loading the trained model, label metadata and UI copy once at startup.
config = load_model_and_config()
intro_md, about_md = load_ui_text()

# Unpack the pieces the individual tabs need.
model = config['model']
class_names = config['class_names']
disease_db = config['disease_db']
device = config['device']


# Custom CSS applied to the whole Blocks app.
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.output-class {
    font-size: 16px;
}
"""

# Create Gradio Interface: one tab per workflow plus an About page.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:

    gr.Markdown(intro_md)

    with gr.Tab("Single Image Prediction"):
        create_single_prediction_tab(model, class_names, disease_db, device)

    with gr.Tab("Batch Processing"):
        create_batch_processing_tab(model, class_names, device)

    with gr.Tab("About"):
        gr.Markdown(about_md)

if __name__ == "__main__":
    # share=False: serve locally only, no public gradio tunnel.
    demo.launch(share=False)
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-4615ab977727fc47.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e87aa03a7d726658b12bfcfc4c8faea5fc24dc6cf75123d005b1443c6a7ca79c
3
+ size 68144
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-64f7f66e875a2297.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94c680bcbb0a710782f6133a110f1eecdbdade266786f4726cdaa94b0189182c
3
+ size 135832
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-a711edc7192ef3fb.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52bd1f2a2bff2a96a8ea9d96739bb7814a193bc4c13a9894f4768c5ef335da3a
3
+ size 316416
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-f446c767d80f0d9a.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60305644da641ec434b5511f370be2e75bea01809259c8d10ffb624dfa57c15d
3
+ size 68136
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/dataset_info.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"description": "", "citation": "", "homepage": "", "license": "", "features": {"image": {"_type": "Image"}, "label": {"names": ["Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy", "Background_without_leaves", "Blueberry___healthy", "Cherry___Powdery_mildew", "Cherry___healthy", "Corn___Cercospora_leaf_spot Gray_leaf_spot", "Corn___Common_rust", "Corn___Northern_Leaf_Blight", "Corn___healthy", "Grape___Black_rot", "Grape___Esca_(Black_Measles)", "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "Grape___healthy", "Orange___Haunglongbing_(Citrus_greening)", "Peach___Bacterial_spot", "Peach___healthy", "Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy", "Potato___Early_blight", "Potato___Late_blight", "Potato___healthy", "Raspberry___healthy", "Soybean___healthy", "Squash___Powdery_mildew", "Strawberry___Leaf_scorch", "Strawberry___healthy", "Tomato___Bacterial_spot", "Tomato___Early_blight", "Tomato___Late_blight", "Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot", "Tomato___Spider_mites Two-spotted_spider_mite", "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus", "Tomato___healthy"], "_type": "ClassLabel"}}, "builder_name": "parquet", "dataset_name": "plant_village", "config_name": "default", "version": {"version_str": "0.0.0", "major": 0, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 863151201, "num_examples": 55447, "shard_lengths": [33224, 22223], "dataset_name": "plant_village"}}, "download_checksums": {"hf://datasets/DScomp380/plant_village@5ce680f815ea9fab7b6f8346ae4c71e7099696a5/data/train-00000-of-00002.parquet": {"num_bytes": 400759198, "checksum": null}, "hf://datasets/DScomp380/plant_village@5ce680f815ea9fab7b6f8346ae4c71e7099696a5/data/train-00001-of-00002.parquet": {"num_bytes": 459968278, "checksum": null}}, "download_size": 860727476, "dataset_size": 863151201, "size_in_bytes": 1723878677}
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00000-of-00002.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9139be0a0731bd3e741d83fcc7ca3f8a892fffb929f8edba4f365537aa67f03
3
+ size 500987424
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00001-of-00002.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0fcda5b77e896d97a0313c78271a9d6ae632fe61db304a8b97dfe85dff9197d
3
+ size 362342232
models/__pycache__/model.cpython-311.pyc ADDED
Binary file (4.21 kB). View file
 
models/__pycache__/model.cpython-313.pyc ADDED
Binary file (3.78 kB). View file
 
models/model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class ConvBlock(nn.Module):
    """Two (3x3 conv -> batch-norm -> ReLU) layers at a fixed spatial size.

    Submodule attribute names (conv_1, batch_norm_1, conv_2, batch_norm_2,
    activation) are kept exactly as before so saved state_dicts still load.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # 3x3 convolutions with padding=1 leave H and W unchanged; bias is
        # omitted because the following batch-norm supplies the shift.
        self.conv_1 = nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.batch_norm_1 = nn.BatchNorm2d(num_features=out_channels)

        self.conv_2 = nn.Conv2d(out_channels, out_channels,
                                kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.batch_norm_2 = nn.BatchNorm2d(num_features=out_channels)

        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> BN -> ReLU twice and return the feature map."""
        hidden = self.conv_1(x)
        hidden = self.activation(self.batch_norm_1(hidden))
        hidden = self.conv_2(hidden)
        return self.activation(self.batch_norm_2(hidden))
27
+
28
class PlantCNN(nn.Module):
    """Image classifier: an input stem, stacked ConvBlock stages each ending
    in 2x2 max-pooling, global average pooling, and a dropout-regularized
    linear head producing raw class logits.

    Attribute names (input_block, stages, pool, dropout, fc) and the layer
    ordering inside each Sequential are unchanged so existing checkpoints
    keep loading.
    """

    def __init__(self, num_classes: int, channels, dropout: float):
        super().__init__()

        stem_width = channels[0]
        # Stem: lift the 3 RGB channels up to the first stage width.
        self.input_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=stem_width,
                      kernel_size=3, stride=1,
                      padding=1, bias=False),
            nn.BatchNorm2d(num_features=stem_width),
            nn.ReLU(inplace=True)
        )

        # One stage per entry in `channels`; each stage halves H and W.
        self.stages = nn.ModuleList()
        prev_width = stem_width
        for width in channels:
            self.stages.append(nn.Sequential(
                ConvBlock(prev_width, width),
                ConvBlock(width, width),
                nn.MaxPool2d(kernel_size=2)
            ))
            prev_width = width

        # Global average pool + dropout + linear classifier head.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(channels[-1], num_classes)  # change for app.py

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        features = self.input_block(x)
        for stage in self.stages:
            features = stage(features)

        # (batch, channels[-1], 1, 1) -> (batch, channels[-1])
        features = torch.flatten(self.pool(features), 1)
        return self.fc(self.dropout(features))
68
+
69
+
requirements.txt ADDED
Binary file (218 Bytes). View file
 
saved_models/plant_cnn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ec7e8ff4511d002f97f9c2805e02a1a9b2900a54c9c11ed58e5d07e3e82fb9e
3
+ size 44127290
tabs/__pycache__/batch_processing.cpython-311.pyc ADDED
Binary file (2.38 kB). View file
 
tabs/__pycache__/single_prediction.cpython-311.pyc ADDED
Binary file (4.32 kB). View file
 
tabs/batch_processing.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils.predictions import predict_batch_visual
3
+
4
# this was written with the help of AI
def create_batch_processing_tab(model, class_names, device):
    """Create the batch processing tab.

    Args:
        model: trained classifier used for inference.
        class_names: class labels, index-aligned with the model outputs.
        device: torch device inference runs on.

    Returns:
        (batch_input, batch_btn, batch_gallery, batch_output) components.
    """

    def batch_predict_wrapper(files, progress=gr.Progress()):
        # Guard against an empty upload before starting any inference work.
        if not files:
            return None, " No files uploaded! Please select images to process."

        progress(0, desc="Starting batch processing...")
        # NOTE(review): `predict_batch_visual` is imported from
        # utils.predictions, but this commit's utils/predictions.py only
        # defines `predict_batch` — confirm the helper actually exists.
        gallery, results = predict_batch_visual(files, model, class_names, device, progress)
        return gallery, results

    gr.Markdown("""
    ### Upload multiple images for batch prediction

    Upload several plant leaf images at once to get predictions for all of them.
    Results will show each image with its prediction and confidence score.
    """)

    batch_input = gr.File(
        file_count="multiple",
        label="Upload Multiple Images",
        file_types=["image"]
    )
    batch_btn = gr.Button("Process All Images", variant="primary", size="lg")

    with gr.Row():
        batch_gallery = gr.Gallery(
            label="Processed Images",
            columns=3,
            height="auto",
            object_fit="contain"
        )

    batch_output = gr.Markdown(label="Detailed Results")

    # Wire the button to the wrapper; outputs map to gallery + markdown.
    batch_btn.click(
        fn=batch_predict_wrapper,
        inputs=batch_input,
        outputs=[batch_gallery, batch_output]
    )

    return batch_input, batch_btn, batch_gallery, batch_output
tabs/single_prediction.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils.predictions import predict_single_image, get_disease_info
3
+ from utils.chart_vis import create_prediction_plot
4
+
5
# this was written with the help of AI
def create_single_prediction_tab(model, class_names, disease_db, device):
    """Create the single image prediction tab.

    Args:
        model: trained classifier used for inference.
        class_names: class labels, index-aligned with the model outputs.
        disease_db: keyword -> disease information text.
        device: torch device inference runs on.

    Returns:
        (input_image, top_n_slider, predict_btn) gradio components.
    """

    def predict_with_visualization(image, show_top_n):
        """Prediction function with all outputs"""
        # The three return values map onto [output_plot, output_info, output_label].
        if image is None:
            return None, "Please upload an image", None

        # Make prediction: list of (label, probability), best first.
        top_preds = predict_single_image(image, model, class_names, device, show_top_n)

        # Create visualization (horizontal confidence bar chart)
        plot = create_prediction_plot(top_preds)

        # Get disease info for the single best class
        top_disease = top_preds[0][0]
        confidence = top_preds[0][1]

        info_text = f"## Top Prediction: {top_disease}\n"
        info_text += f"**Confidence:** {confidence:.2%}\n\n"
        info_text += f"{get_disease_info(top_disease, disease_db)}\n\n"

        # Flag weak predictions so users don't over-trust the model.
        if confidence < 0.5:
            info_text += "**Note:** Low confidence. Consider expert verification."

        # Results dictionary for the gr.Label component (label -> probability)
        results_dict = {label: round(float(prob), 4) for label, prob in top_preds}

        return plot, info_text, results_dict

    # Bundled sample images shipped with the app.
    examples = [
        ["ui_text/examples/Apple___Apple_scab.jpg"],
        ["ui_text/examples/Tomato___healthy.jpg"],
        ["ui_text/examples/Tomato___Bacterial_spot.jpg"],
        ["ui_text/examples/Tomato___Tomato_Yellow_Leaf_Curl_Virus.jpg"],
        ["ui_text/examples/Tomato___Septoria_leaf_spot.jpg"],
        ["ui_text/examples/Soybean___healthy.jpg"]
    ]

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Plant Leaf Image")
            top_n_slider = gr.Slider(
                minimum=3,
                maximum=15,
                value=10,
                step=1,
                label="Number of top predictions to show"
            )
            predict_btn = gr.Button("Analyze Disease", variant="primary")

            gr.Markdown("### Example Images")
            gr.Examples(
                examples=examples,
                inputs=input_image,
                label="Click an example to try it out",
                cache_examples=False
            )

        with gr.Column(scale=1):
            output_plot = gr.Plot(label="Prediction Confidence Chart")
            output_info = gr.Markdown(label="Disease Information")
            # NOTE(review): num_top_classes is fixed at 10 while the slider
            # allows up to 15 — confirm whether it should track the slider.
            output_label = gr.Label(label="Detailed Predictions", num_top_classes=10)

    # Connect button
    predict_btn.click(
        fn=predict_with_visualization,
        inputs=[input_image, top_n_slider],
        outputs=[output_plot, output_info, output_label]
    )

    return input_image, top_n_slider, predict_btn
ui_text/about.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model Information
3
+
4
+ **Architecture:** Custom PlantCNN
5
+ - **Channels:** [96, 192, 384, 768]
6
+ - **Number of Classes:** 39
7
+ - **Input Size:** 224x224 pixels
8
+ - **Framework:** PyTorch
9
+
10
+ ## Features
11
+
12
+ - Real-time disease detection
13
+ - Confidence visualization with histogram
14
+ - Top-N predictions customization
15
+ - Batch processing support
16
+ - Pre-loaded example gallery
17
+ - Disease information and treatment suggestions
18
+
19
+ ## How to Use
20
+
21
+ 1. Upload a clear image of a plant leaf
22
+ 2. Adjust the number of predictions if needed
23
+ 3. Click "Analyze Disease" to get results
24
+ 4. Review the confidence chart and recommendations
25
+
ui_text/class_names.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "Apple___Apple_scab",
3
+ "Apple___Black_rot",
4
+ "Apple___Cedar_apple_rust",
5
+ "Apple___healthy",
6
+ "Background_without_leaves",
7
+ "Blueberry___healthy",
8
+ "Cherry___Powdery_mildew",
9
+ "Cherry___healthy",
10
+ "Corn___Cercospora_leaf_spot_Gray_leaf_spot",
11
+ "Corn___Common_rust",
12
+ "Corn___Northern_Leaf_Blight",
13
+ "Corn___healthy",
14
+ "Grape___Black_rot",
15
+ "Grape___Esca_(Black_Measles)",
16
+ "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
17
+ "Grape___healthy",
18
+ "Orange___Haunglongbing_(Citrus_greening)",
19
+ "Peach___Bacterial_spot",
20
+ "Peach___healthy",
21
+ "Pepper,_bell___Bacterial_spot",
22
+ "Pepper,_bell___healthy",
23
+ "Potato___Early_blight",
24
+ "Potato___Late_blight",
25
+ "Potato___healthy",
26
+ "Raspberry___healthy",
27
+ "Soybean___healthy",
28
+ "Squash___Powdery_mildew",
29
+ "Strawberry___Leaf_scorch",
30
+ "Strawberry___healthy",
31
+ "Tomato___Bacterial_spot",
32
+ "Tomato___Early_blight",
33
+ "Tomato___Late_blight",
34
+ "Tomato___Leaf_Mold",
35
+ "Tomato___Septoria_leaf_spot",
36
+ "Tomato___Spider_mites_Two-spotted_spider_mite",
37
+ "Tomato___Target_Spot",
38
+ "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
39
+ "Tomato___Tomato_mosaic_virus",
40
+ "Tomato___healthy"
41
+ ]
ui_text/disease_info.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "healthy": "✅ No disease detected! The plant appears healthy.",
3
+ "scab": "Apple Scab: A fungal disease causing dark, scabby lesions. Treatment: Apply fungicides and remove infected leaves.",
4
+ "rust": "Rust: Fungal infection with orange/brown pustules. Treatment: Use fungicides and improve air circulation.",
5
+ "spot": "Leaf Spot: Bacterial or fungal spots on leaves. Treatment: Remove infected parts and apply appropriate treatment.",
6
+ "blight": "Blight: A severe disease causing rapid plant death. Treatment: Remove infected plants and use resistant varieties.",
7
+ "mold": "Mold: Fungal growth on leaves. Treatment: Improve ventilation and reduce humidity.",
8
+ "virus": "Viral disease: Transmitted by insects. Treatment: Remove infected plants and control insect vectors."
9
+ }
ui_text/examples/Apple___Apple_scab.jpg ADDED
ui_text/examples/Soybean___healthy.jpg ADDED
ui_text/examples/Tomato___Bacterial_spot.jpg ADDED
ui_text/examples/Tomato___Septoria_leaf_spot.jpg ADDED
ui_text/examples/Tomato___Tomato_Yellow_Leaf_Curl_Virus.jpg ADDED
ui_text/examples/Tomato___healthy.jpg ADDED
ui_text/intro.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ # Plant Disease Detection System
3
+
4
+ Upload an image of a plant leaf to detect diseases using our trained CNN model.
5
+ The model can identify **39 different plant diseases and healthy conditions**.
utils/__pycache__/chart_vis.cpython-311.pyc ADDED
Binary file (2.38 kB). View file
 
utils/__pycache__/config.cpython-311.pyc ADDED
Binary file (639 Bytes). View file
 
utils/__pycache__/config.cpython-313.pyc ADDED
Binary file (477 Bytes). View file
 
utils/__pycache__/model_loader.cpython-311.pyc ADDED
Binary file (2.95 kB). View file
 
utils/__pycache__/predictions.cpython-311.pyc ADDED
Binary file (7.79 kB). View file
 
utils/__pycache__/vis.cpython-311.pyc ADDED
Binary file (7.07 kB). View file
 
utils/chart_vis.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
def create_prediction_plot(top_preds):
    """Render the top predictions as a horizontal bar chart.

    Args:
        top_preds: list of (label, probability) pairs, best first.

    Returns:
        The matplotlib Figure (caller is responsible for closing it).
    """
    names = [name for name, _ in top_preds]
    scores = [score for _, score in top_preds]

    fig, ax = plt.subplots(figsize=(10, 6))

    positions = np.arange(len(names))
    # Map each confidence onto the red->green colormap so stronger
    # predictions render greener.
    bar_colors = plt.cm.RdYlGn(np.array(scores))

    ax.barh(positions, scores, color=bar_colors, alpha=0.8)
    ax.set_yticks(positions)
    ax.set_yticklabels(names)
    ax.invert_yaxis()  # best prediction drawn at the top
    ax.set_xlabel('Confidence Score', fontsize=12)
    ax.set_title('Top Disease Predictions', fontsize=14, fontweight='bold')
    ax.set_xlim([0, 1])

    # Annotate each bar with its numeric score just past the bar end.
    for idx, score in enumerate(scores):
        ax.text(score + 0.01, idx, f'{score:.3f}',
                va='center', fontsize=10)

    plt.tight_layout()
    return fig
31
+
utils/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import yaml
2
+
3
+ # helper function to load configs from yaml file
4
def load_config(path="config.yaml"):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        path: Path to the YAML file (defaults to "config.yaml").

    Returns:
        The deserialized YAML document (typically a dict), or None for an
        empty file.

    Raises:
        FileNotFoundError: if *path* does not exist.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Explicit UTF-8: YAML is UTF-8 by spec, and relying on the platform's
    # locale default can mis-decode non-ASCII config values.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
utils/model_loader.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import os
4
+ from models.model import PlantCNN
5
+
6
def load_model_and_config():
    """Load the trained PlantCNN plus the label / disease metadata.

    Returns:
        dict with keys 'model' (eval-mode PlantCNN), 'class_names'
        (list of labels, index-aligned with model outputs), 'disease_db'
        (keyword -> info text), and 'device' (torch.device).

    Raises:
        FileNotFoundError: if the saved weights file is missing.
    """

    # Paths
    MODEL_PATH = "saved_models/plant_cnn.pt"
    CLASS_NAMES_PATH = "ui_text/class_names.json"

    # Architecture hyper-parameters — must match the training run that
    # produced MODEL_PATH, otherwise load_state_dict() will fail.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CHANNELS = [96, 192, 384, 768]
    DROPOUT = 0.4
    NUM_CLASSES = 39

    # Load class names (index -> label, same ordering as training)
    with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
        class_names = json.load(f)

    # Load disease info keyed by lowercase disease keyword
    with open("ui_text/disease_info.json", "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    # Build the network, then load the trained weights
    model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
    if not os.path.exists(MODEL_PATH):
        # Fail loudly instead of the previous silent exit(): callers get a
        # clear, catchable error explaining exactly what is missing.
        raise FileNotFoundError(
            f"Model weights not found at '{MODEL_PATH}'. "
            "Train the model or download plant_cnn.pt before launching the app."
        )

    print("Loading trained model weights...")
    # NOTE(review): torch.load unpickles data; consider weights_only=True
    # (torch >= 1.13) since this checkpoint should only contain tensors.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()  # inference mode: disables dropout, uses BN running stats

    return {
        'model': model,
        'class_names': class_names,
        'disease_db': disease_db,
        'device': DEVICE
    }
42
+
43
def load_ui_text():
    """Read the intro and about markdown snippets shown in the UI.

    Returns:
        (intro_md, about_md): contents of ui_text/intro.md and
        ui_text/about.md as strings.
    """
    # Small local helper so both files are read with identical settings.
    def _read(path):
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read()

    return _read("ui_text/intro.md"), _read("ui_text/about.md")
utils/predictions.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import os
5
+
6
# this was written with the help of AI
# Preprocessing applied to every input image before inference: resize to the
# 224x224 input size the model expects, convert to a CHW float tensor, then
# normalize with the standard ImageNet channel statistics (presumably the
# same normalization used during training — confirm against the training
# pipeline).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
12
+
13
def get_disease_info(disease_name, disease_db):
    """Look up advice text for a predicted class name.

    Matching is by substring: the first key in *disease_db* that occurs
    inside the lower-cased class name wins (dict insertion order breaks ties).

    Args:
        disease_name: predicted class label, e.g. "Apple___Apple_scab".
        disease_db: mapping of lowercase keyword -> information string.

    Returns:
        The matching info string, or a fallback message when nothing matches.
    """
    needle = disease_name.lower()
    return next(
        (info for keyword, info in disease_db.items() if keyword in needle),
        " Disease information not available for this classification.",
    )
20
+
21
def predict_single_image(image, model, class_names, device, show_top_n=10):
    """Classify one image and return the top-N (label, probability) pairs.

    Args:
        image: PIL image, or None (in which case None is returned).
        model: trained classifier used for inference.
        class_names: labels indexed the same way as the model outputs.
        device: torch device to run inference on.
        show_top_n: how many of the highest-probability classes to return.

    Returns:
        List of (label, probability) tuples sorted by descending probability,
        or None when no image was supplied.
    """
    if image is None:
        return None

    # Preprocess and add the batch dimension the model expects.
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(batch)
        scores = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze().numpy()

    # Pair every label with its probability, best first, then keep the top N.
    ranked = sorted(zip(class_names, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:show_top_n]
37
+
38
def predict_batch(files, model, class_names, device):
    """Classify several uploaded image files and summarize as markdown.

    Args:
        files: iterable of uploaded file objects exposing a ``.name`` path
            (gradio-style), or a falsy value.
        model: trained classifier used for inference.
        class_names: labels indexed the same way as the model outputs.
        device: torch device to run inference on.

    Returns:
        Markdown string with one "**filename**: label (confidence)" entry
        per file, or "No files uploaded" when *files* is empty.
    """
    if not files:
        return "No files uploaded"

    lines = []
    for upload in files:
        # Force RGB so grayscale/RGBA uploads still fit the 3-channel model.
        img = Image.open(upload.name).convert('RGB')
        batch = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
            scores = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze().numpy()

        best_label = class_names[scores.argmax()]
        best_score = scores.max()

        lines.append(f"**{os.path.basename(upload.name)}**: {best_label} ({best_score:.2%})")

    return "\n\n".join(lines)
58
+
utils/vis.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from sklearn.metrics import confusion_matrix
4
+ import torch
5
+
6
def to_display_image(img_tensor, mean, std):
    """Undo channel normalization and convert a CHW tensor to an HWC array.

    Args:
        img_tensor: normalized image tensor of shape (3, H, W).
        mean: per-channel means used during normalization (length 3).
        std: per-channel standard deviations used during normalization.

    Returns:
        numpy array of shape (H, W, 3) with values clipped to [0, 1],
        suitable for matplotlib's imshow.
    """
    # Bug fix: ``.numpy()`` on a CPU tensor shares memory with it, so the
    # in-place denormalization below used to mutate the caller's tensor.
    # Copy first; detach() also makes this safe for grad-carrying tensors.
    img = img_tensor.detach().cpu().numpy().copy()
    for c in range(3):
        img[c] = img[c] * std[c] + mean[c]
    img = np.clip(img, 0.0, 1.0)
    img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for display
    return img
13
+
14
def visualize_preds(images, labels, preds, logger, class_names, mean, std, num_images):
    """Log a grid of sample predictions (green title = correct, red = wrong).

    Args:
        images: sequence of normalized CHW image tensors.
        labels: ground-truth class indices, parallel to *images*.
        preds: predicted class indices, parallel to *images*.
        logger: reporter exposing report_matplotlib_figure (presumably a
            ClearML logger — confirm).
        class_names: index -> human-readable label.
        mean, std: normalization stats used to undo preprocessing for display.
        num_images: maximum number of samples to draw.
    """
    num_images = min(num_images, len(images))
    rows = int(np.ceil(num_images / 4))
    fig, axs = plt.subplots(rows, 4, figsize=(24, 6 * rows))
    axs = axs.flatten()

    for i, ax in enumerate(axs):
        ax.axis("off")  # hide axes on every cell, including unused ones
        # Bug fix: cap at num_images, not len(images) — the old guard let the
        # grid fill all rows*4 cells with images beyond the requested count
        # whenever len(images) exceeded num_images.
        if i >= num_images:
            continue

        img = to_display_image(images[i], mean, std)
        lbl = labels[i]
        pr = preds[i]

        ax.imshow(img)
        title = f"Label: {class_names[lbl]}\nPrediction: {class_names[pr]}"
        colour = "green" if lbl == pr else "red"
        ax.set_title(title, fontsize=16, color=colour)

    fig.tight_layout()
    logger.report_matplotlib_figure("sample_predictions", "test", fig, iteration=0)
    plt.close(fig)  # free the figure; the logger has already captured it
37
+
38
def plot_cfm(labels, preds, logger, class_names, num_classes):
    """Report normalized and errors-only confusion matrices to the logger.

    Args:
        labels: ground-truth class indices.
        preds: predicted class indices, parallel to *labels*.
        logger: reporter exposing report_matplotlib_figure (presumably a
            ClearML logger — confirm).
        class_names: index -> label, length num_classes.
        num_classes: total class count (fixes the matrix size even when some
            classes are absent from *labels*).
    """
    cfm = confusion_matrix(labels, preds, labels=list(range(num_classes)))
    # Row-normalize; classes with no ground-truth samples give 0/0 -> NaN,
    # which is mapped back to 0 (errstate silences the expected warning).
    with np.errstate(invalid="ignore"):
        cfm_norm = cfm / cfm.sum(axis=1, keepdims=True)
    cfm_norm = np.nan_to_num(cfm_norm)

    fig, ax = plt.subplots(figsize=(16, 16))
    im = ax.imshow(cfm_norm, interpolation="nearest", cmap="Blues")
    # Bug fix: Figure.colorbar's second positional parameter is cax, not ax —
    # the old `fig.colorbar(im, ax)` drew the colorbar INTO the matrix axes.
    # The duplicated colorbar call is dropped as well.
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Fraction of sample", rotation=90)
    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(class_names, rotation=90, fontsize=8)
    ax.set_yticklabels(class_names, fontsize=8)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Ground Truth")
    ax.set_title("Confusion matrix (Normalized)")

    # Overlay each non-zero cell value; white text on dark cells for contrast.
    threshold = cfm_norm.max() / 2.0
    for i in range(num_classes):
        for j in range(num_classes):
            value = cfm_norm[i, j]
            if value == 0:
                continue
            ax.text(j, i, f"{value:.2f}", ha="center", va="center",
                    fontsize=5, color="white" if value > threshold else "black")

    fig.tight_layout()
    logger.report_matplotlib_figure(title="normalized_confusion_matrix", series="test", figure=fig, iteration=0)
    plt.close(fig)

    # Second figure: raw error counts only (the diagonal is zeroed out so the
    # colormap isn't dominated by the correct predictions).
    cfm_errors = cfm.copy()
    np.fill_diagonal(cfm_errors, 0)
    if cfm_errors.max() > 0:
        fig_err, ax_err = plt.subplots(figsize=(18, 18))
        im_err = ax_err.imshow(cfm_errors, interpolation="nearest", cmap=plt.cm.Blues)
        cbar_err = fig_err.colorbar(im_err, ax=ax_err)
        cbar_err.ax.set_ylabel("Number of misclassified samples", rotation=90)
        ax_err.set_title("Confusion matrix (errors only)")
        ax_err.set_xlabel("Predicted")
        ax_err.set_ylabel("Ground Truth")
        ax_err.set_xticks(np.arange(len(class_names)))
        ax_err.set_yticks(np.arange(len(class_names)))
        ax_err.set_xticklabels(class_names, rotation=90, fontsize=8)
        ax_err.set_yticklabels(class_names, fontsize=8)

        threshold = cfm_errors.max() / 2.0
        for i in range(num_classes):
            for j in range(num_classes):
                value = cfm_errors[i, j]
                if value == 0:
                    continue
                ax_err.text(j, i, str(value), ha="center", va="center",
                            fontsize=5, color="white" if value > threshold else "black")

        fig_err.tight_layout()
        logger.report_matplotlib_figure(title="errors_only_confusion_matrix", series="test", figure=fig_err, iteration=0)
        plt.close(fig_err)
    else:
        print("No misclassifications")
98
+
99
+
100
+