Spaces:
Sleeping
Sleeping
Upload 40 files
Browse files- .gitignore +249 -0
- __pycache__/data_prep.cpython-311.pyc +0 -0
- app.py +41 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-4615ab977727fc47.arrow +3 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-64f7f66e875a2297.arrow +3 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-a711edc7192ef3fb.arrow +3 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-f446c767d80f0d9a.arrow +3 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/dataset_info.json +1 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00000-of-00002.arrow +3 -0
- data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00001-of-00002.arrow +3 -0
- models/__pycache__/model.cpython-311.pyc +0 -0
- models/__pycache__/model.cpython-313.pyc +0 -0
- models/model.py +69 -0
- requirements.txt +0 -0
- saved_models/plant_cnn.pt +3 -0
- tabs/__pycache__/batch_processing.cpython-311.pyc +0 -0
- tabs/__pycache__/single_prediction.cpython-311.pyc +0 -0
- tabs/batch_processing.py +46 -0
- tabs/single_prediction.py +77 -0
- ui_text/about.md +25 -0
- ui_text/class_names.json +41 -0
- ui_text/disease_info.json +9 -0
- ui_text/examples/Apple___Apple_scab.jpg +0 -0
- ui_text/examples/Soybean___healthy.jpg +0 -0
- ui_text/examples/Tomato___Bacterial_spot.jpg +0 -0
- ui_text/examples/Tomato___Septoria_leaf_spot.jpg +0 -0
- ui_text/examples/Tomato___Tomato_Yellow_Leaf_Curl_Virus.jpg +0 -0
- ui_text/examples/Tomato___healthy.jpg +0 -0
- ui_text/intro.md +5 -0
- utils/__pycache__/chart_vis.cpython-311.pyc +0 -0
- utils/__pycache__/config.cpython-311.pyc +0 -0
- utils/__pycache__/config.cpython-313.pyc +0 -0
- utils/__pycache__/model_loader.cpython-311.pyc +0 -0
- utils/__pycache__/predictions.cpython-311.pyc +0 -0
- utils/__pycache__/vis.cpython-311.pyc +0 -0
- utils/chart_vis.py +31 -0
- utils/config.py +6 -0
- utils/model_loader.py +51 -0
- utils/predictions.py +58 -0
- utils/vis.py +100 -0
.gitignore
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.gap_gpu_env/
|
| 2 |
+
saved_models/
|
| 3 |
+
__pychache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Created by https://www.toptal.com/developers/gitignore/api/venv,macos,python,visualstudiocode
|
| 8 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=venv,macos,python,visualstudiocode
|
| 9 |
+
|
| 10 |
+
### macOS ###
|
| 11 |
+
# General
|
| 12 |
+
.DS_Store
|
| 13 |
+
.AppleDouble
|
| 14 |
+
.LSOverride
|
| 15 |
+
|
| 16 |
+
# Icon must end with two \r
|
| 17 |
+
Icon
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Thumbnails
|
| 21 |
+
._*
|
| 22 |
+
|
| 23 |
+
# Files that might appear in the root of a volume
|
| 24 |
+
.DocumentRevisions-V100
|
| 25 |
+
.fseventsd
|
| 26 |
+
.Spotlight-V100
|
| 27 |
+
.TemporaryItems
|
| 28 |
+
.Trashes
|
| 29 |
+
.VolumeIcon.icns
|
| 30 |
+
.com.apple.timemachine.donotpresent
|
| 31 |
+
|
| 32 |
+
# Directories potentially created on remote AFP share
|
| 33 |
+
.AppleDB
|
| 34 |
+
.AppleDesktop
|
| 35 |
+
Network Trash Folder
|
| 36 |
+
Temporary Items
|
| 37 |
+
.apdisk
|
| 38 |
+
|
| 39 |
+
### macOS Patch ###
|
| 40 |
+
# iCloud generated files
|
| 41 |
+
*.icloud
|
| 42 |
+
|
| 43 |
+
### Python ###
|
| 44 |
+
# Byte-compiled / optimized / DLL files
|
| 45 |
+
__pycache__/
|
| 46 |
+
*.py[cod]
|
| 47 |
+
*$py.class
|
| 48 |
+
|
| 49 |
+
# C extensions
|
| 50 |
+
*.so
|
| 51 |
+
|
| 52 |
+
# Distribution / packaging
|
| 53 |
+
.Python
|
| 54 |
+
build/
|
| 55 |
+
develop-eggs/
|
| 56 |
+
dist/
|
| 57 |
+
downloads/
|
| 58 |
+
eggs/
|
| 59 |
+
.eggs/
|
| 60 |
+
lib/
|
| 61 |
+
lib64/
|
| 62 |
+
parts/
|
| 63 |
+
sdist/
|
| 64 |
+
var/
|
| 65 |
+
wheels/
|
| 66 |
+
share/python-wheels/
|
| 67 |
+
*.egg-info/
|
| 68 |
+
.installed.cfg
|
| 69 |
+
*.egg
|
| 70 |
+
MANIFEST
|
| 71 |
+
|
| 72 |
+
# PyInstaller
|
| 73 |
+
# Usually these files are written by a python script from a template
|
| 74 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 75 |
+
*.manifest
|
| 76 |
+
*.spec
|
| 77 |
+
|
| 78 |
+
# Installer logs
|
| 79 |
+
pip-log.txt
|
| 80 |
+
pip-delete-this-directory.txt
|
| 81 |
+
|
| 82 |
+
# Unit test / coverage reports
|
| 83 |
+
htmlcov/
|
| 84 |
+
.tox/
|
| 85 |
+
.nox/
|
| 86 |
+
.coverage
|
| 87 |
+
.coverage.*
|
| 88 |
+
.cache
|
| 89 |
+
nosetests.xml
|
| 90 |
+
coverage.xml
|
| 91 |
+
*.cover
|
| 92 |
+
*.py,cover
|
| 93 |
+
.hypothesis/
|
| 94 |
+
.pytest_cache/
|
| 95 |
+
cover/
|
| 96 |
+
|
| 97 |
+
# Translations
|
| 98 |
+
*.mo
|
| 99 |
+
*.pot
|
| 100 |
+
|
| 101 |
+
# Django stuff:
|
| 102 |
+
*.log
|
| 103 |
+
local_settings.py
|
| 104 |
+
db.sqlite3
|
| 105 |
+
db.sqlite3-journal
|
| 106 |
+
|
| 107 |
+
# Flask stuff:
|
| 108 |
+
instance/
|
| 109 |
+
.webassets-cache
|
| 110 |
+
|
| 111 |
+
# Scrapy stuff:
|
| 112 |
+
.scrapy
|
| 113 |
+
|
| 114 |
+
# Sphinx documentation
|
| 115 |
+
docs/_build/
|
| 116 |
+
|
| 117 |
+
# PyBuilder
|
| 118 |
+
.pybuilder/
|
| 119 |
+
target/
|
| 120 |
+
|
| 121 |
+
# Jupyter Notebook
|
| 122 |
+
.ipynb_checkpoints
|
| 123 |
+
|
| 124 |
+
# IPython
|
| 125 |
+
profile_default/
|
| 126 |
+
ipython_config.py
|
| 127 |
+
|
| 128 |
+
# pyenv
|
| 129 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 130 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 131 |
+
# .python-version
|
| 132 |
+
|
| 133 |
+
# pipenv
|
| 134 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 135 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 136 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 137 |
+
# install all needed dependencies.
|
| 138 |
+
#Pipfile.lock
|
| 139 |
+
|
| 140 |
+
# poetry
|
| 141 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 142 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 143 |
+
# commonly ignored for libraries.
|
| 144 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 145 |
+
#poetry.lock
|
| 146 |
+
|
| 147 |
+
# pdm
|
| 148 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 149 |
+
#pdm.lock
|
| 150 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 151 |
+
# in version control.
|
| 152 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 153 |
+
.pdm.toml
|
| 154 |
+
|
| 155 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 156 |
+
__pypackages__/
|
| 157 |
+
|
| 158 |
+
# Celery stuff
|
| 159 |
+
celerybeat-schedule
|
| 160 |
+
celerybeat.pid
|
| 161 |
+
|
| 162 |
+
# SageMath parsed files
|
| 163 |
+
*.sage.py
|
| 164 |
+
|
| 165 |
+
# Environments
|
| 166 |
+
.env
|
| 167 |
+
.venv
|
| 168 |
+
env/
|
| 169 |
+
venv/
|
| 170 |
+
ENV/
|
| 171 |
+
env.bak/
|
| 172 |
+
venv.bak/
|
| 173 |
+
|
| 174 |
+
# Spyder project settings
|
| 175 |
+
.spyderproject
|
| 176 |
+
.spyproject
|
| 177 |
+
|
| 178 |
+
# Rope project settings
|
| 179 |
+
.ropeproject
|
| 180 |
+
|
| 181 |
+
# mkdocs documentation
|
| 182 |
+
/site
|
| 183 |
+
|
| 184 |
+
# mypy
|
| 185 |
+
.mypy_cache/
|
| 186 |
+
.dmypy.json
|
| 187 |
+
dmypy.json
|
| 188 |
+
|
| 189 |
+
# Pyre type checker
|
| 190 |
+
.pyre/
|
| 191 |
+
|
| 192 |
+
# pytype static type analyzer
|
| 193 |
+
.pytype/
|
| 194 |
+
|
| 195 |
+
# Cython debug symbols
|
| 196 |
+
cython_debug/
|
| 197 |
+
|
| 198 |
+
# PyCharm
|
| 199 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 200 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 201 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 202 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 203 |
+
#.idea/
|
| 204 |
+
|
| 205 |
+
### Python Patch ###
|
| 206 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
| 207 |
+
poetry.toml
|
| 208 |
+
|
| 209 |
+
# ruff
|
| 210 |
+
.ruff_cache/
|
| 211 |
+
|
| 212 |
+
# LSP config files
|
| 213 |
+
pyrightconfig.json
|
| 214 |
+
|
| 215 |
+
### venv ###
|
| 216 |
+
# Virtualenv
|
| 217 |
+
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
| 218 |
+
[Bb]in
|
| 219 |
+
[Ii]nclude
|
| 220 |
+
[Ll]ib
|
| 221 |
+
[Ll]ib64
|
| 222 |
+
[Ll]ocal
|
| 223 |
+
[Ss]cripts
|
| 224 |
+
pyvenv.cfg
|
| 225 |
+
pip-selfcheck.json
|
| 226 |
+
|
| 227 |
+
### VisualStudioCode ###
|
| 228 |
+
.vscode/*
|
| 229 |
+
!.vscode/settings.json
|
| 230 |
+
!.vscode/tasks.json
|
| 231 |
+
!.vscode/launch.json
|
| 232 |
+
!.vscode/extensions.json
|
| 233 |
+
!.vscode/*.code-snippets
|
| 234 |
+
|
| 235 |
+
# Local History for Visual Studio Code
|
| 236 |
+
.history/
|
| 237 |
+
|
| 238 |
+
# Built Visual Studio Code Extensions
|
| 239 |
+
*.vsix
|
| 240 |
+
|
| 241 |
+
### VisualStudioCode Patch ###
|
| 242 |
+
# Ignore all local history of files
|
| 243 |
+
.history
|
| 244 |
+
.ionide
|
| 245 |
+
|
| 246 |
+
# End of https://www.toptal.com/developers/gitignore/api/venv,macos,python,visualstudiocode
|
| 247 |
+
|
| 248 |
+
data/
|
| 249 |
+
*.arrow
|
__pycache__/data_prep.cpython-311.pyc
ADDED
|
Binary file (4.62 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from utils.model_loader import load_model_and_config, load_ui_text
|
| 3 |
+
from tabs.single_prediction import create_single_prediction_tab
|
| 4 |
+
from tabs.batch_processing import create_batch_processing_tab
|
| 5 |
+
|
| 6 |
+
# loadign model
|
| 7 |
+
config = load_model_and_config()
|
| 8 |
+
intro_md, about_md = load_ui_text()
|
| 9 |
+
|
| 10 |
+
model = config['model']
|
| 11 |
+
class_names = config['class_names']
|
| 12 |
+
disease_db = config['disease_db']
|
| 13 |
+
device = config['device']
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Custom CSS
|
| 17 |
+
custom_css = """
|
| 18 |
+
.gradio-container {
|
| 19 |
+
font-family: 'Arial', sans-serif;
|
| 20 |
+
}
|
| 21 |
+
.output-class {
|
| 22 |
+
font-size: 16px;
|
| 23 |
+
}
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
# Create Gradio Interface
|
| 27 |
+
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 28 |
+
|
| 29 |
+
gr.Markdown(intro_md)
|
| 30 |
+
|
| 31 |
+
with gr.Tab("Single Image Prediction"):
|
| 32 |
+
create_single_prediction_tab(model, class_names, disease_db, device)
|
| 33 |
+
|
| 34 |
+
with gr.Tab("Batch Processing"):
|
| 35 |
+
create_batch_processing_tab(model, class_names, device)
|
| 36 |
+
|
| 37 |
+
with gr.Tab("About"):
|
| 38 |
+
gr.Markdown(about_md)
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
demo.launch(share=False)
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-4615ab977727fc47.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e87aa03a7d726658b12bfcfc4c8faea5fc24dc6cf75123d005b1443c6a7ca79c
|
| 3 |
+
size 68144
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-64f7f66e875a2297.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94c680bcbb0a710782f6133a110f1eecdbdade266786f4726cdaa94b0189182c
|
| 3 |
+
size 135832
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-a711edc7192ef3fb.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52bd1f2a2bff2a96a8ea9d96739bb7814a193bc4c13a9894f4768c5ef335da3a
|
| 3 |
+
size 316416
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/cache-f446c767d80f0d9a.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60305644da641ec434b5511f370be2e75bea01809259c8d10ffb624dfa57c15d
|
| 3 |
+
size 68136
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/dataset_info.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"description": "", "citation": "", "homepage": "", "license": "", "features": {"image": {"_type": "Image"}, "label": {"names": ["Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy", "Background_without_leaves", "Blueberry___healthy", "Cherry___Powdery_mildew", "Cherry___healthy", "Corn___Cercospora_leaf_spot Gray_leaf_spot", "Corn___Common_rust", "Corn___Northern_Leaf_Blight", "Corn___healthy", "Grape___Black_rot", "Grape___Esca_(Black_Measles)", "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "Grape___healthy", "Orange___Haunglongbing_(Citrus_greening)", "Peach___Bacterial_spot", "Peach___healthy", "Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy", "Potato___Early_blight", "Potato___Late_blight", "Potato___healthy", "Raspberry___healthy", "Soybean___healthy", "Squash___Powdery_mildew", "Strawberry___Leaf_scorch", "Strawberry___healthy", "Tomato___Bacterial_spot", "Tomato___Early_blight", "Tomato___Late_blight", "Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot", "Tomato___Spider_mites Two-spotted_spider_mite", "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus", "Tomato___healthy"], "_type": "ClassLabel"}}, "builder_name": "parquet", "dataset_name": "plant_village", "config_name": "default", "version": {"version_str": "0.0.0", "major": 0, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 863151201, "num_examples": 55447, "shard_lengths": [33224, 22223], "dataset_name": "plant_village"}}, "download_checksums": {"hf://datasets/DScomp380/plant_village@5ce680f815ea9fab7b6f8346ae4c71e7099696a5/data/train-00000-of-00002.parquet": {"num_bytes": 400759198, "checksum": null}, "hf://datasets/DScomp380/plant_village@5ce680f815ea9fab7b6f8346ae4c71e7099696a5/data/train-00001-of-00002.parquet": {"num_bytes": 459968278, "checksum": null}}, "download_size": 860727476, "dataset_size": 863151201, "size_in_bytes": 1723878677}
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00000-of-00002.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9139be0a0731bd3e741d83fcc7ca3f8a892fffb929f8edba4f365537aa67f03
|
| 3 |
+
size 500987424
|
data/plant_village/DScomp380___plant_village/default/0.0.0/5ce680f815ea9fab7b6f8346ae4c71e7099696a5/plant_village-train-00001-of-00002.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0fcda5b77e896d97a0313c78271a9d6ae632fe61db304a8b97dfe85dff9197d
|
| 3 |
+
size 362342232
|
models/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
models/__pycache__/model.cpython-313.pyc
ADDED
|
Binary file (3.78 kB). View file
|
|
|
models/model.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class ConvBlock(nn.Module):
|
| 5 |
+
def __init__(self, in_channels:int, out_channels:int) :
|
| 6 |
+
super().__init__()
|
| 7 |
+
|
| 8 |
+
# first convolutional layer
|
| 9 |
+
self.conv_1 = nn.Conv2d(in_channels, out_channels,
|
| 10 |
+
kernel_size=3, stride=1,
|
| 11 |
+
padding=1, bias=False)
|
| 12 |
+
self.batch_norm_1 = nn.BatchNorm2d(num_features=out_channels)
|
| 13 |
+
|
| 14 |
+
# second convolutional layer
|
| 15 |
+
self.conv_2 = nn.Conv2d(out_channels, out_channels,
|
| 16 |
+
kernel_size=3, stride=1,
|
| 17 |
+
padding=1, bias=False)
|
| 18 |
+
self.batch_norm_2 = nn.BatchNorm2d(num_features=out_channels)
|
| 19 |
+
|
| 20 |
+
self.activation = nn.ReLU(inplace=True)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
# basic conv -> bn -> relu forward pass
|
| 24 |
+
output = self.activation(self.batch_norm_1(self.conv_1(x)))
|
| 25 |
+
output = self.activation(self.batch_norm_2(self.conv_2(output)))
|
| 26 |
+
return output
|
| 27 |
+
|
| 28 |
+
class PlantCNN(nn.Module):
|
| 29 |
+
def __init__(self, num_classes:int, channels, dropout: float):
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
# entry block to map RGB -> 64 channels
|
| 33 |
+
first_c = channels[0]
|
| 34 |
+
self.input_block = nn.Sequential(
|
| 35 |
+
nn.Conv2d(in_channels=3, out_channels=first_c,
|
| 36 |
+
kernel_size=3, stride=1,
|
| 37 |
+
padding=1, bias=False),
|
| 38 |
+
nn.BatchNorm2d(num_features=first_c),
|
| 39 |
+
nn.ReLU(inplace=True)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
self.stages = nn.ModuleList()
|
| 43 |
+
in_c = first_c
|
| 44 |
+
for c in channels:
|
| 45 |
+
stage = nn.Sequential(
|
| 46 |
+
ConvBlock(in_c,c),
|
| 47 |
+
ConvBlock(c,c),
|
| 48 |
+
nn.MaxPool2d(kernel_size=2)
|
| 49 |
+
)
|
| 50 |
+
self.stages.append(stage)
|
| 51 |
+
in_c = c
|
| 52 |
+
|
| 53 |
+
# final pooling + classifer
|
| 54 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 55 |
+
self.dropout = nn.Dropout(dropout)
|
| 56 |
+
self.fc = nn.Linear(channels[-1], num_classes) #change for app.py
|
| 57 |
+
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
output = self.input_block(x)
|
| 60 |
+
# pass through each stage in order
|
| 61 |
+
for stage in self.stages:
|
| 62 |
+
output = stage(output)
|
| 63 |
+
|
| 64 |
+
# pool to (batch, 512) then flatten
|
| 65 |
+
output = torch.flatten(self.pool(output), 1)
|
| 66 |
+
output = self.fc(self.dropout(output))
|
| 67 |
+
return output
|
| 68 |
+
|
| 69 |
+
|
requirements.txt
ADDED
|
Binary file (218 Bytes). View file
|
|
|
saved_models/plant_cnn.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ec7e8ff4511d002f97f9c2805e02a1a9b2900a54c9c11ed58e5d07e3e82fb9e
|
| 3 |
+
size 44127290
|
tabs/__pycache__/batch_processing.cpython-311.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
tabs/__pycache__/single_prediction.cpython-311.pyc
ADDED
|
Binary file (4.32 kB). View file
|
|
|
tabs/batch_processing.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from utils.predictions import predict_batch_visual
|
| 3 |
+
|
| 4 |
+
# this was written with the help of AI
|
| 5 |
+
def create_batch_processing_tab(model, class_names, device):
|
| 6 |
+
"""Create the batch processing tab"""
|
| 7 |
+
|
| 8 |
+
def batch_predict_wrapper(files, progress=gr.Progress()):
|
| 9 |
+
if not files:
|
| 10 |
+
return None, " No files uploaded! Please select images to process."
|
| 11 |
+
|
| 12 |
+
progress(0, desc="Starting batch processing...")
|
| 13 |
+
gallery, results = predict_batch_visual(files, model, class_names, device, progress)
|
| 14 |
+
return gallery, results
|
| 15 |
+
|
| 16 |
+
gr.Markdown("""
|
| 17 |
+
### Upload multiple images for batch prediction
|
| 18 |
+
|
| 19 |
+
Upload several plant leaf images at once to get predictions for all of them.
|
| 20 |
+
Results will show each image with its prediction and confidence score.
|
| 21 |
+
""")
|
| 22 |
+
|
| 23 |
+
batch_input = gr.File(
|
| 24 |
+
file_count="multiple",
|
| 25 |
+
label="Upload Multiple Images",
|
| 26 |
+
file_types=["image"]
|
| 27 |
+
)
|
| 28 |
+
batch_btn = gr.Button("Process All Images", variant="primary", size="lg")
|
| 29 |
+
|
| 30 |
+
with gr.Row():
|
| 31 |
+
batch_gallery = gr.Gallery(
|
| 32 |
+
label="Processed Images",
|
| 33 |
+
columns=3,
|
| 34 |
+
height="auto",
|
| 35 |
+
object_fit="contain"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
batch_output = gr.Markdown(label="Detailed Results")
|
| 39 |
+
|
| 40 |
+
batch_btn.click(
|
| 41 |
+
fn=batch_predict_wrapper,
|
| 42 |
+
inputs=batch_input,
|
| 43 |
+
outputs=[batch_gallery, batch_output]
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
return batch_input, batch_btn, batch_gallery, batch_output
|
tabs/single_prediction.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from utils.predictions import predict_single_image, get_disease_info
|
| 3 |
+
from utils.chart_vis import create_prediction_plot
|
| 4 |
+
|
| 5 |
+
# this was written with the help of AI
|
| 6 |
+
def create_single_prediction_tab(model, class_names, disease_db, device):
|
| 7 |
+
"""Create the single image prediction tab"""
|
| 8 |
+
|
| 9 |
+
def predict_with_visualization(image, show_top_n):
|
| 10 |
+
"""Prediction function with all outputs"""
|
| 11 |
+
if image is None:
|
| 12 |
+
return None, "Please upload an image", None
|
| 13 |
+
|
| 14 |
+
# Make prediction
|
| 15 |
+
top_preds = predict_single_image(image, model, class_names, device, show_top_n)
|
| 16 |
+
|
| 17 |
+
# Create visualization
|
| 18 |
+
plot = create_prediction_plot(top_preds)
|
| 19 |
+
|
| 20 |
+
# Get disease info
|
| 21 |
+
top_disease = top_preds[0][0]
|
| 22 |
+
confidence = top_preds[0][1]
|
| 23 |
+
|
| 24 |
+
info_text = f"## Top Prediction: {top_disease}\n"
|
| 25 |
+
info_text += f"**Confidence:** {confidence:.2%}\n\n"
|
| 26 |
+
info_text += f"{get_disease_info(top_disease, disease_db)}\n\n"
|
| 27 |
+
|
| 28 |
+
if confidence < 0.5:
|
| 29 |
+
info_text += "**Note:** Low confidence. Consider expert verification."
|
| 30 |
+
|
| 31 |
+
# Results dictionary
|
| 32 |
+
results_dict = {label: round(float(prob), 4) for label, prob in top_preds}
|
| 33 |
+
|
| 34 |
+
return plot, info_text, results_dict
|
| 35 |
+
|
| 36 |
+
examples = [
|
| 37 |
+
["ui_text/examples/Apple___Apple_scab.jpg"],
|
| 38 |
+
["ui_text/examples/Tomato___healthy.jpg"],
|
| 39 |
+
["ui_text/examples/Tomato___Bacterial_spot.jpg"],
|
| 40 |
+
["ui_text/examples/Tomato___Tomato_Yellow_Leaf_Curl_Virus.jpg"],
|
| 41 |
+
["ui_text/examples/Tomato___Septoria_leaf_spot.jpg"],
|
| 42 |
+
["ui_text/examples/Soybean___healthy.jpg"]
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
with gr.Row():
|
| 46 |
+
with gr.Column(scale=1):
|
| 47 |
+
input_image = gr.Image(type="pil", label="Upload Plant Leaf Image")
|
| 48 |
+
top_n_slider = gr.Slider(
|
| 49 |
+
minimum=3,
|
| 50 |
+
maximum=15,
|
| 51 |
+
value=10,
|
| 52 |
+
step=1,
|
| 53 |
+
label="Number of top predictions to show"
|
| 54 |
+
)
|
| 55 |
+
predict_btn = gr.Button("Analyze Disease", variant="primary")
|
| 56 |
+
|
| 57 |
+
gr.Markdown("### Example Images")
|
| 58 |
+
gr.Examples(
|
| 59 |
+
examples=examples,
|
| 60 |
+
inputs=input_image,
|
| 61 |
+
label="Click an example to try it out",
|
| 62 |
+
cache_examples=False
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
with gr.Column(scale=1):
|
| 66 |
+
output_plot = gr.Plot(label="Prediction Confidence Chart")
|
| 67 |
+
output_info = gr.Markdown(label="Disease Information")
|
| 68 |
+
output_label = gr.Label(label="Detailed Predictions", num_top_classes=10)
|
| 69 |
+
|
| 70 |
+
# Connect button
|
| 71 |
+
predict_btn.click(
|
| 72 |
+
fn=predict_with_visualization,
|
| 73 |
+
inputs=[input_image, top_n_slider],
|
| 74 |
+
outputs=[output_plot, output_info, output_label]
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return input_image, top_n_slider, predict_btn
|
ui_text/about.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
## Model Information
|
| 3 |
+
|
| 4 |
+
**Architecture:** Custom PlantCNN
|
| 5 |
+
- **Channels:** [96, 192, 384, 768]
|
| 6 |
+
- **Number of Classes:** 39
|
| 7 |
+
- **Input Size:** 224x224 pixels
|
| 8 |
+
- **Framework:** PyTorch
|
| 9 |
+
|
| 10 |
+
## Features
|
| 11 |
+
|
| 12 |
+
- Real-time disease detection
|
| 13 |
+
- Confidence visualization with histogram
|
| 14 |
+
- Top-N predictions customization
|
| 15 |
+
- Batch processing support
|
| 16 |
+
- Pre-loaded example gallery
|
| 17 |
+
- Disease information and treatment suggestions
|
| 18 |
+
|
| 19 |
+
## How to Use
|
| 20 |
+
|
| 21 |
+
1. Upload a clear image of a plant leaf
|
| 22 |
+
2. Adjust the number of predictions if needed
|
| 23 |
+
3. Click "Analyze Disease" to get results
|
| 24 |
+
4. Review the confidence chart and recommendations
|
| 25 |
+
|
ui_text/class_names.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"Apple___Apple_scab",
|
| 3 |
+
"Apple___Black_rot",
|
| 4 |
+
"Apple___Cedar_apple_rust",
|
| 5 |
+
"Apple___healthy",
|
| 6 |
+
"Background_without_leaves",
|
| 7 |
+
"Blueberry___healthy",
|
| 8 |
+
"Cherry___Powdery_mildew",
|
| 9 |
+
"Cherry___healthy",
|
| 10 |
+
"Corn___Cercospora_leaf_spot_Gray_leaf_spot",
|
| 11 |
+
"Corn___Common_rust",
|
| 12 |
+
"Corn___Northern_Leaf_Blight",
|
| 13 |
+
"Corn___healthy",
|
| 14 |
+
"Grape___Black_rot",
|
| 15 |
+
"Grape___Esca_(Black_Measles)",
|
| 16 |
+
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
|
| 17 |
+
"Grape___healthy",
|
| 18 |
+
"Orange___Haunglongbing_(Citrus_greening)",
|
| 19 |
+
"Peach___Bacterial_spot",
|
| 20 |
+
"Peach___healthy",
|
| 21 |
+
"Pepper,_bell___Bacterial_spot",
|
| 22 |
+
"Pepper,_bell___healthy",
|
| 23 |
+
"Potato___Early_blight",
|
| 24 |
+
"Potato___Late_blight",
|
| 25 |
+
"Potato___healthy",
|
| 26 |
+
"Raspberry___healthy",
|
| 27 |
+
"Soybean___healthy",
|
| 28 |
+
"Squash___Powdery_mildew",
|
| 29 |
+
"Strawberry___Leaf_scorch",
|
| 30 |
+
"Strawberry___healthy",
|
| 31 |
+
"Tomato___Bacterial_spot",
|
| 32 |
+
"Tomato___Early_blight",
|
| 33 |
+
"Tomato___Late_blight",
|
| 34 |
+
"Tomato___Leaf_Mold",
|
| 35 |
+
"Tomato___Septoria_leaf_spot",
|
| 36 |
+
"Tomato___Spider_mites_Two-spotted_spider_mite",
|
| 37 |
+
"Tomato___Target_Spot",
|
| 38 |
+
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
|
| 39 |
+
"Tomato___Tomato_mosaic_virus",
|
| 40 |
+
"Tomato___healthy"
|
| 41 |
+
]
|
ui_text/disease_info.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"healthy": "✅ No disease detected! The plant appears healthy.",
|
| 3 |
+
"scab": "Apple Scab: A fungal disease causing dark, scabby lesions. Treatment: Apply fungicides and remove infected leaves.",
|
| 4 |
+
"rust": "Rust: Fungal infection with orange/brown pustules. Treatment: Use fungicides and improve air circulation.",
|
| 5 |
+
"spot": "Leaf Spot: Bacterial or fungal spots on leaves. Treatment: Remove infected parts and apply appropriate treatment.",
|
| 6 |
+
"blight": "Blight: A severe disease causing rapid plant death. Treatment: Remove infected plants and use resistant varieties.",
|
| 7 |
+
"mold": "Mold: Fungal growth on leaves. Treatment: Improve ventilation and reduce humidity.",
|
| 8 |
+
"virus": "Viral disease: Transmitted by insects. Treatment: Remove infected plants and control insect vectors."
|
| 9 |
+
}
|
ui_text/examples/Apple___Apple_scab.jpg
ADDED
|
ui_text/examples/Soybean___healthy.jpg
ADDED
|
ui_text/examples/Tomato___Bacterial_spot.jpg
ADDED
|
ui_text/examples/Tomato___Septoria_leaf_spot.jpg
ADDED
|
ui_text/examples/Tomato___Tomato_Yellow_Leaf_Curl_Virus.jpg
ADDED
|
ui_text/examples/Tomato___healthy.jpg
ADDED
|
ui_text/intro.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Plant Disease Detection System
|
| 3 |
+
|
| 4 |
+
Upload an image of a plant leaf to detect diseases using our trained CNN model.
|
| 5 |
+
The model can identify **39 different plant diseases and healthy conditions**.
|
utils/__pycache__/chart_vis.cpython-311.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
utils/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (639 Bytes). View file
|
|
|
utils/__pycache__/config.cpython-313.pyc
ADDED
|
Binary file (477 Bytes). View file
|
|
|
utils/__pycache__/model_loader.cpython-311.pyc
ADDED
|
Binary file (2.95 kB). View file
|
|
|
utils/__pycache__/predictions.cpython-311.pyc
ADDED
|
Binary file (7.79 kB). View file
|
|
|
utils/__pycache__/vis.cpython-311.pyc
ADDED
|
Binary file (7.07 kB). View file
|
|
|
utils/chart_vis.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
def create_prediction_plot(top_preds):
|
| 5 |
+
"""Create a horizontal bar chart of predictions"""
|
| 6 |
+
labels = [label for label, _ in top_preds]
|
| 7 |
+
probs = [prob for _, prob in top_preds]
|
| 8 |
+
|
| 9 |
+
# figure
|
| 10 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 11 |
+
|
| 12 |
+
# horizontal bar chart
|
| 13 |
+
y_pos = np.arange(len(labels))
|
| 14 |
+
colors = plt.cm.RdYlGn(np.array(probs)) # Color based on confidence AI idea
|
| 15 |
+
|
| 16 |
+
ax.barh(y_pos, probs, color=colors, alpha=0.8)
|
| 17 |
+
ax.set_yticks(y_pos)
|
| 18 |
+
ax.set_yticklabels(labels)
|
| 19 |
+
ax.invert_yaxis() # Top prediction at the top
|
| 20 |
+
ax.set_xlabel('Confidence Score', fontsize=12)
|
| 21 |
+
ax.set_title('Top Disease Predictions', fontsize=14, fontweight='bold')
|
| 22 |
+
ax.set_xlim([0, 1])
|
| 23 |
+
|
| 24 |
+
# Add value labels on bars
|
| 25 |
+
for i, (label, prob) in enumerate(zip(labels, probs)):
|
| 26 |
+
ax.text(prob + 0.01, i, f'{prob:.3f}',
|
| 27 |
+
va='center', fontsize=10)
|
| 28 |
+
|
| 29 |
+
plt.tight_layout()
|
| 30 |
+
return fig
|
| 31 |
+
|
utils/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
|
| 3 |
+
# helper function to load configs from yaml file
|
| 4 |
+
def load_config(path="config.yaml"):
|
| 5 |
+
with open(path, "r") as f:
|
| 6 |
+
return yaml.safe_load(f)
|
utils/model_loader.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from models.model import PlantCNN
|
| 5 |
+
|
| 6 |
+
def load_model_and_config():
|
| 7 |
+
"""Load the trained model and all configuration files"""
|
| 8 |
+
|
| 9 |
+
# Paths
|
| 10 |
+
MODEL_PATH = "saved_models/plant_cnn.pt"
|
| 11 |
+
CLASS_NAMES_PATH = "ui_text/class_names.json"
|
| 12 |
+
|
| 13 |
+
# Config
|
| 14 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
+
CHANNELS = [96, 192, 384, 768]
|
| 16 |
+
DROPOUT = 0.4
|
| 17 |
+
NUM_CLASSES = 39
|
| 18 |
+
|
| 19 |
+
# Load class names
|
| 20 |
+
with open(CLASS_NAMES_PATH, "r") as f:
|
| 21 |
+
class_names = json.load(f)
|
| 22 |
+
|
| 23 |
+
# Load disease info
|
| 24 |
+
with open("ui_text/disease_info.json", "r", encoding="utf-8") as f:
|
| 25 |
+
disease_db = json.load(f)
|
| 26 |
+
|
| 27 |
+
# Load model
|
| 28 |
+
model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
|
| 29 |
+
if os.path.exists(MODEL_PATH):
|
| 30 |
+
print("Loading trained model weights...")
|
| 31 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
|
| 32 |
+
model.eval()
|
| 33 |
+
else:
|
| 34 |
+
exit()
|
| 35 |
+
|
| 36 |
+
return {
|
| 37 |
+
'model': model,
|
| 38 |
+
'class_names': class_names,
|
| 39 |
+
'disease_db': disease_db,
|
| 40 |
+
'device': DEVICE
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
def load_ui_text():
|
| 44 |
+
"""Load intro and about markdown files"""
|
| 45 |
+
with open("ui_text/intro.md", "r", encoding="utf-8") as f:
|
| 46 |
+
intro_md = f.read()
|
| 47 |
+
|
| 48 |
+
with open("ui_text/about.md", "r", encoding="utf-8") as f:
|
| 49 |
+
about_md = f.read()
|
| 50 |
+
|
| 51 |
+
return intro_md, about_md
|
utils/predictions.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision.transforms as transforms
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# this was written with the help of AI
|
| 7 |
+
transform = transforms.Compose([
|
| 8 |
+
transforms.Resize((224, 224)),
|
| 9 |
+
transforms.ToTensor(),
|
| 10 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 11 |
+
])
|
| 12 |
+
|
| 13 |
+
def get_disease_info(disease_name, disease_db):
|
| 14 |
+
"""Get information about a detected disease"""
|
| 15 |
+
disease_lower = disease_name.lower()
|
| 16 |
+
for key, info in disease_db.items():
|
| 17 |
+
if key in disease_lower:
|
| 18 |
+
return info
|
| 19 |
+
return " Disease information not available for this classification."
|
| 20 |
+
|
| 21 |
+
def predict_single_image(image, model, class_names, device, show_top_n=10):
|
| 22 |
+
"""Make prediction on a single image"""
|
| 23 |
+
if image is None:
|
| 24 |
+
return None
|
| 25 |
+
|
| 26 |
+
# Transform and predict
|
| 27 |
+
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
outputs = model(image_tensor)
|
| 30 |
+
probs = torch.nn.functional.softmax(outputs, dim=1).cpu().squeeze().numpy()
|
| 31 |
+
|
| 32 |
+
# Get top predictions
|
| 33 |
+
all_preds = sorted(zip(class_names, probs), key=lambda x: x[1], reverse=True)
|
| 34 |
+
top_preds = all_preds[:show_top_n]
|
| 35 |
+
|
| 36 |
+
return top_preds
|
| 37 |
+
|
| 38 |
+
def predict_batch(files, model, class_names, device):
|
| 39 |
+
"""Process multiple images at once"""
|
| 40 |
+
if not files:
|
| 41 |
+
return "No files uploaded"
|
| 42 |
+
|
| 43 |
+
results = []
|
| 44 |
+
for file in files:
|
| 45 |
+
image = Image.open(file.name).convert('RGB')
|
| 46 |
+
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 47 |
+
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
outputs = model(image_tensor)
|
| 50 |
+
probs = torch.nn.functional.softmax(outputs, dim=1).cpu().squeeze().numpy()
|
| 51 |
+
|
| 52 |
+
top_pred = class_names[probs.argmax()]
|
| 53 |
+
confidence = probs.max()
|
| 54 |
+
|
| 55 |
+
results.append(f"**{os.path.basename(file.name)}**: {top_pred} ({confidence:.2%})")
|
| 56 |
+
|
| 57 |
+
return "\n\n".join(results)
|
| 58 |
+
|
utils/vis.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from sklearn.metrics import confusion_matrix
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
def to_display_image(img_tensor, mean, std):
|
| 7 |
+
img = img_tensor.cpu().numpy()
|
| 8 |
+
for c in range(3):
|
| 9 |
+
img[c] = img[c]*std[c]+mean[c]
|
| 10 |
+
img = np.clip(img, 0.0, 1.0)
|
| 11 |
+
img = np.transpose(img, (1,2,0))
|
| 12 |
+
return img
|
| 13 |
+
|
| 14 |
+
def visualize_preds(images, labels, preds, logger, class_names, mean, std, num_images):
|
| 15 |
+
num_images = min(num_images, len(images))
|
| 16 |
+
rows = int(np.ceil(num_images/4))
|
| 17 |
+
fig, axs = plt.subplots(rows, 4, figsize=(24, 6*rows))
|
| 18 |
+
axs = axs.flatten()
|
| 19 |
+
|
| 20 |
+
for i, ax in enumerate(axs):
|
| 21 |
+
ax.axis("off")
|
| 22 |
+
if i >= len(images):
|
| 23 |
+
continue
|
| 24 |
+
|
| 25 |
+
img = to_display_image(images[i], mean, std)
|
| 26 |
+
lbl = labels[i]
|
| 27 |
+
pr = preds[i]
|
| 28 |
+
|
| 29 |
+
ax.imshow(img)
|
| 30 |
+
title = f"Label: {class_names[lbl]}\nPrediction: {class_names[pr]}"
|
| 31 |
+
colour = "green" if lbl == pr else "red"
|
| 32 |
+
ax.set_title(title, fontsize=16, color=colour)
|
| 33 |
+
|
| 34 |
+
fig.tight_layout()
|
| 35 |
+
logger.report_matplotlib_figure("sample_predictions", "test", fig, iteration=0)
|
| 36 |
+
plt.close(fig)
|
| 37 |
+
|
| 38 |
+
def plot_cfm(labels, preds, logger, class_names, num_classes):
|
| 39 |
+
cfm = confusion_matrix(labels, preds, labels=list(range(num_classes)))
|
| 40 |
+
cfm_norm = cfm/cfm.sum(axis=1, keepdims=True)
|
| 41 |
+
cfm_norm = np.nan_to_num(cfm_norm)
|
| 42 |
+
|
| 43 |
+
fig, ax = plt.subplots(figsize=(16, 16))
|
| 44 |
+
im = ax.imshow(cfm_norm, interpolation="nearest", cmap="Blues")
|
| 45 |
+
cbar = fig.colorbar(im, ax)
|
| 46 |
+
cbar.ax.set_ylabel("Fraction of sample", rotation=90)
|
| 47 |
+
fig.colorbar(im, ax)
|
| 48 |
+
ax.set_xticks(range(num_classes))
|
| 49 |
+
ax.set_yticks(range(num_classes))
|
| 50 |
+
ax.set_xticklabels(class_names, rotation=90, fontsize=8)
|
| 51 |
+
ax.set_yticklabels(class_names, fontsize=8)
|
| 52 |
+
ax.set_xlabel("Predicted")
|
| 53 |
+
ax.set_ylabel("Ground Truth")
|
| 54 |
+
ax.set_title("Confusion matrix (Normalized)")
|
| 55 |
+
|
| 56 |
+
threshold = cfm_norm.max() / 2.0
|
| 57 |
+
for i in range(num_classes):
|
| 58 |
+
for j in range(num_classes):
|
| 59 |
+
value = cfm_norm[i, j]
|
| 60 |
+
if value == 0:
|
| 61 |
+
continue
|
| 62 |
+
ax.text(j, i, f"{value:.2f}", ha="center", va="center",
|
| 63 |
+
fontsize=5, color="white" if value > threshold else "black")
|
| 64 |
+
|
| 65 |
+
fig.tight_layout()
|
| 66 |
+
logger.report_matplotlib_figure(title="normalized_confusion_matrix", series="test", figure=fig, iteration=0)
|
| 67 |
+
plt.close(fig)
|
| 68 |
+
|
| 69 |
+
cfm_errors = cfm.copy()
|
| 70 |
+
np.fill_diagonal(cfm_errors, 0)
|
| 71 |
+
if cfm_errors.max() > 0:
|
| 72 |
+
fig_err, ax_err = plt.subplots(figsize=(18, 18))
|
| 73 |
+
im_err = ax_err.imshow(cfm_errors, interpolation="nearest", cmap=plt.cm.Blues)
|
| 74 |
+
cbar_err = fig_err.colorbar(im_err, ax=ax_err)
|
| 75 |
+
cbar_err.ax.set_ylabel("Number of misclassified samples", rotation=90)
|
| 76 |
+
ax_err.set_title("Confusion matrix (errors only)")
|
| 77 |
+
ax_err.set_xlabel("Predicted")
|
| 78 |
+
ax_err.set_ylabel("Ground Truth")
|
| 79 |
+
ax_err.set_xticks(np.arange(len(class_names)))
|
| 80 |
+
ax_err.set_yticks(np.arange(len(class_names)))
|
| 81 |
+
ax_err.set_xticklabels(class_names, rotation=90, fontsize=8)
|
| 82 |
+
ax_err.set_yticklabels(class_names, fontsize=8)
|
| 83 |
+
|
| 84 |
+
threshold = cfm_errors.max() / 2.0
|
| 85 |
+
for i in range(num_classes):
|
| 86 |
+
for j in range(num_classes):
|
| 87 |
+
value = cfm_errors[i, j]
|
| 88 |
+
if value == 0:
|
| 89 |
+
continue
|
| 90 |
+
ax_err.text(j, i, str(value), ha="center", va="center",
|
| 91 |
+
fontsize=5, color="white" if value > threshold else "black")
|
| 92 |
+
|
| 93 |
+
fig_err.tight_layout()
|
| 94 |
+
logger.report_matplotlib_figure(title="errors_only_confusion_matrix", series="test", figure=fig_err, iteration=0)
|
| 95 |
+
plt.close(fig_err)
|
| 96 |
+
else:
|
| 97 |
+
print("No misclassifications")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|