Spaces:
Sleeping
Sleeping
Upload github repo files
Browse filesSync up spaces repo with main github repo
- .gitignore +188 -0
- README.md +103 -13
- app.py +151 -0
- configs/train_cnn.yaml +24 -0
- notebooks/data_prep_hristina.ipynb +0 -0
- notebooks/eda.ipynb +0 -0
- process_dataset.py +31 -0
- requirements.txt +11 -0
- src/DataLoader/__init__.py +0 -0
- src/DataLoader/dataloader.py +44 -0
- src/DataLoader/plantvillage_dataset.py +50 -0
- src/DataLoader/utils.py +33 -0
- src/__pycache__/inference.cpython-312.pyc +0 -0
- src/evaluate.py +120 -0
- src/inference.py +289 -0
- src/models/__pycache__/cnn_model.cpython-312.pyc +0 -0
- src/models/__pycache__/resnet18_finetune.cpython-312.pyc +0 -0
- src/models/cnn_model.py +57 -0
- src/models/resnet18_finetune.py +20 -0
- src/train/early_stopping.py +14 -0
- src/train/train.py +163 -0
- src/utils/__pycache__/config.cpython-312.pyc +0 -0
- src/utils/class_names.py +41 -0
- src/utils/config.py +6 -0
- src/utils/metrics.py +9 -0
.gitignore
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python
|
| 2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
| 3 |
+
|
| 4 |
+
### Python ###
|
| 5 |
+
# Byte-compiled / optimized / DLL files
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[cod]
|
| 8 |
+
*$py.class
|
| 9 |
+
|
| 10 |
+
# C extensions
|
| 11 |
+
*.so
|
| 12 |
+
|
| 13 |
+
# Distribution / packaging
|
| 14 |
+
.Python
|
| 15 |
+
build/
|
| 16 |
+
develop-eggs/
|
| 17 |
+
dist/
|
| 18 |
+
downloads/
|
| 19 |
+
eggs/
|
| 20 |
+
.eggs/
|
| 21 |
+
lib/
|
| 22 |
+
lib64/
|
| 23 |
+
parts/
|
| 24 |
+
sdist/
|
| 25 |
+
var/
|
| 26 |
+
wheels/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py,cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
cover/
|
| 57 |
+
|
| 58 |
+
# Translations
|
| 59 |
+
*.mo
|
| 60 |
+
*.pot
|
| 61 |
+
|
| 62 |
+
# Django stuff:
|
| 63 |
+
*.log
|
| 64 |
+
local_settings.py
|
| 65 |
+
db.sqlite3
|
| 66 |
+
db.sqlite3-journal
|
| 67 |
+
|
| 68 |
+
# Flask stuff:
|
| 69 |
+
instance/
|
| 70 |
+
.webassets-cache
|
| 71 |
+
|
| 72 |
+
# Scrapy stuff:
|
| 73 |
+
.scrapy
|
| 74 |
+
|
| 75 |
+
# Sphinx documentation
|
| 76 |
+
docs/_build/
|
| 77 |
+
|
| 78 |
+
# PyBuilder
|
| 79 |
+
.pybuilder/
|
| 80 |
+
target/
|
| 81 |
+
|
| 82 |
+
# Jupyter Notebook
|
| 83 |
+
.ipynb_checkpoints
|
| 84 |
+
|
| 85 |
+
# IPython
|
| 86 |
+
profile_default/
|
| 87 |
+
ipython_config.py
|
| 88 |
+
|
| 89 |
+
# pyenv
|
| 90 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 92 |
+
# .python-version
|
| 93 |
+
|
| 94 |
+
# pipenv
|
| 95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 98 |
+
# install all needed dependencies.
|
| 99 |
+
#Pipfile.lock
|
| 100 |
+
|
| 101 |
+
# poetry
|
| 102 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 104 |
+
# commonly ignored for libraries.
|
| 105 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 106 |
+
#poetry.lock
|
| 107 |
+
|
| 108 |
+
# pdm
|
| 109 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 110 |
+
#pdm.lock
|
| 111 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 112 |
+
# in version control.
|
| 113 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 114 |
+
.pdm.toml
|
| 115 |
+
|
| 116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 117 |
+
__pypackages__/
|
| 118 |
+
|
| 119 |
+
# Celery stuff
|
| 120 |
+
celerybeat-schedule
|
| 121 |
+
celerybeat.pid
|
| 122 |
+
|
| 123 |
+
# SageMath parsed files
|
| 124 |
+
*.sage.py
|
| 125 |
+
|
| 126 |
+
# Environments
|
| 127 |
+
.env
|
| 128 |
+
.venv
|
| 129 |
+
env/
|
| 130 |
+
venv/
|
| 131 |
+
ENV/
|
| 132 |
+
env.bak/
|
| 133 |
+
venv.bak/
|
| 134 |
+
|
| 135 |
+
# Spyder project settings
|
| 136 |
+
.spyderproject
|
| 137 |
+
.spyproject
|
| 138 |
+
|
| 139 |
+
# Rope project settings
|
| 140 |
+
.ropeproject
|
| 141 |
+
|
| 142 |
+
# mkdocs documentation
|
| 143 |
+
/site
|
| 144 |
+
|
| 145 |
+
# mypy
|
| 146 |
+
.mypy_cache/
|
| 147 |
+
.dmypy.json
|
| 148 |
+
dmypy.json
|
| 149 |
+
|
| 150 |
+
# Pyre type checker
|
| 151 |
+
.pyre/
|
| 152 |
+
|
| 153 |
+
# pytype static type analyzer
|
| 154 |
+
.pytype/
|
| 155 |
+
|
| 156 |
+
# Cython debug symbols
|
| 157 |
+
cython_debug/
|
| 158 |
+
|
| 159 |
+
# PyCharm
|
| 160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 164 |
+
#.idea/
|
| 165 |
+
|
| 166 |
+
### Python Patch ###
|
| 167 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
| 168 |
+
poetry.toml
|
| 169 |
+
|
| 170 |
+
# ruff
|
| 171 |
+
.ruff_cache/
|
| 172 |
+
|
| 173 |
+
# LSP config files
|
| 174 |
+
pyrightconfig.json
|
| 175 |
+
|
| 176 |
+
# End of https://www.toptal.com/developers/gitignore/api/python
|
| 177 |
+
|
| 178 |
+
# Data specific files/folders
|
| 179 |
+
data/
|
| 180 |
+
|
| 181 |
+
# Mac's stuff
|
| 182 |
+
.DS_Store
|
| 183 |
+
|
| 184 |
+
# Model
|
| 185 |
+
checkpoints/
|
| 186 |
+
|
| 187 |
+
# UI
|
| 188 |
+
flagged/
|
README.md
CHANGED
|
@@ -1,13 +1,103 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 5CCSAGAP: AI/Robotics Group Project - Plant Disease Detection
|
| 2 |
+
|
| 3 |
+
This repository contains the source code for the "Plant Disease Detector" application, developed for the 5CCSAGAP module.
|
| 4 |
+
|
| 5 |
+
## 1. Project Overview
|
| 6 |
+
|
| 7 |
+
The goal of this project is to build an AI system capable of detecting plant diseases from leaf images. The application is built with a Python backend using PyTorch and is deployed with a Gradio user interface. The system supports two models: a custom-built CNN and a high-performance model based on ResNet18 transfer learning.
|
| 8 |
+
|
| 9 |
+
**Team 4 Members:**
|
| 10 |
+
- Oguzhan Cagirir
|
| 11 |
+
- Janit Bhardwaj
|
| 12 |
+
- Hristina Georgieva
|
| 13 |
+
- Hissan Omar
|
| 14 |
+
- Kasim Morsel
|
| 15 |
+
- Mark Soltyk
|
| 16 |
+
|
| 17 |
+
## 2. Running the Application
|
| 18 |
+
|
| 19 |
+
This guide is for running the Gradio application locally to demonstrate its functionality.
|
| 20 |
+
|
| 21 |
+
### Prerequisites
|
| 22 |
+
- Python 3.10+
|
| 23 |
+
- An active internet connection (for downloading model weights on first run).
|
| 24 |
+
- Git LFS is **not** required as large files are hosted externally.
|
| 25 |
+
|
| 26 |
+
### Setup Instructions
|
| 27 |
+
|
| 28 |
+
1. **Clone the repository:**
|
| 29 |
+
```bash
|
| 30 |
+
git clone https://github.kcl.ac.uk/k23136072/Small-group-project.git
|
| 31 |
+
cd Small-group-project
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
2. **Create and activate a Python virtual environment:**
|
| 35 |
+
|
| 36 |
+
**On Windows:**
|
| 37 |
+
```bash
|
| 38 |
+
python -m venv .venv
|
| 39 |
+
.venv\Scripts\activate
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
**On Linux/macOS:**
|
| 43 |
+
```bash
|
| 44 |
+
python3 -m venv .venv
|
| 45 |
+
source .venv/bin/activate
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
3. **Install dependencies:**
|
| 49 |
+
```bash
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Launching the Gradio UI
|
| 54 |
+
|
| 55 |
+
Once the setup is complete, launch the application:
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
python app.py
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
The application will start and provide a local URL (e.g., `http://127.0.0.1:7860`). Open this URL in your web browser.
|
| 62 |
+
|
| 63 |
+
**Note on First Run:** The first time you make a prediction with a specific model (e.g., "ResNet18"), the application will automatically download the required model weights (~45MB) from Hugging Face Hub. This may take a moment. Subsequent runs will use the cached local copy.
|
| 64 |
+
|
| 65 |
+
## 3. Repository Structure
|
| 66 |
+
|
| 67 |
+
- `app.py`: The main Gradio application entry point.
|
| 68 |
+
- `src/`: Contains all core source code.
|
| 69 |
+
- `src/inference.py`: The core API for running model predictions. Handles model downloading and preprocessing.
|
| 70 |
+
- `src/models/`: Contains the PyTorch model definitions (`cnn_model.py`, `resnet18_finetune.py`).
|
| 71 |
+
- `src/DataLoader/`: Contains the data loading and processing logic used for training.
|
| 72 |
+
- `src/train/`: Contains the training (`train.py`) and evaluation (`evaluate.py`) scripts.
|
| 73 |
+
- `src/utils/`: Configuration and helper utilities.
|
| 74 |
+
- `configs/`: Contains the YAML configuration file for training.
|
| 75 |
+
- `requirements.txt`: Project dependencies.
|
| 76 |
+
- `process_dataset.py`: Utility script to download and prepare the dataset from Hugging Face.
|
| 77 |
+
- `README.md`: This file.
|
| 78 |
+
|
| 79 |
+
## 4. Model Information
|
| 80 |
+
|
| 81 |
+
The application supports two models, selectable via the UI dropdown:
|
| 82 |
+
|
| 83 |
+
1. **ResNet18 (Default):** A high-performance model using transfer learning from a pretrained ResNet18. Achieves **~96% accuracy** on the test set.
|
| 84 |
+
2. **CNN:** A custom-built baseline CNN, trained from scratch. Achieves **~78% accuracy** on the test set.
|
| 85 |
+
|
| 86 |
+
Model weights are hosted on Hugging Face Hub at `MZaik/Plant_Disease_Detection` and are downloaded automatically by the application.
|
| 87 |
+
|
| 88 |
+
## 5. Training and Evaluation (Advanced)
|
| 89 |
+
|
| 90 |
+
To reproduce the training or evaluation runs, the repository is integrated with ClearML for experiment tracking.
|
| 91 |
+
|
| 92 |
+
**Training:**
|
| 93 |
+
1. Configure `configs/train_cnn.yaml` (set `model_type` to `resnet18` or `cnn`).
|
| 94 |
+
2. Run `python -m src.train.train`.
|
| 95 |
+
* This will automatically trigger data processing if the dataset is missing.
|
| 96 |
+
* It will execute on the configured ClearML server if credentials are present, or locally if not.
|
| 97 |
+
|
| 98 |
+
**Evaluation:**
|
| 99 |
+
1. Obtain a Model Task ID from a successful training run on ClearML.
|
| 100 |
+
2. Run `python -m src.evaluate --model_type resnet18 --task_id [TASK_ID]`.
|
| 101 |
+
* This generates the confusion matrix and classification report artifacts.
|
| 102 |
+
|
| 103 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from src.inference import predict_image
|
| 7 |
+
|
| 8 |
+
SAMPLE_IMGS = [
|
| 9 |
+
"plant_gallery/img1.jpg",
|
| 10 |
+
"plant_gallery/img2.jpg",
|
| 11 |
+
"plant_gallery/img3.jpg",
|
| 12 |
+
"plant_gallery/img4.jpg",
|
| 13 |
+
"plant_gallery/img5.jpg",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
MODEL_REGISTRY = {
|
| 17 |
+
|
| 18 |
+
"ResNet18": {
|
| 19 |
+
"model_type": "resnet18",
|
| 20 |
+
"clearml_model_id": "", # Model ID can be left blank since task artifacts are in use within the KCL system.
|
| 21 |
+
"clearml_task_id": "85a9b1c01454493087b5c068a4f21ec6",
|
| 22 |
+
"hf_repo_id": "MZaik/Plant_Disease_Detection",
|
| 23 |
+
"hf_filename": "plant_disease_resnet18.pt",
|
| 24 |
+
"local_path": "checkpoints/best_resnet18.pt",
|
| 25 |
+
},
|
| 26 |
+
|
| 27 |
+
"CNN": {
|
| 28 |
+
"model_type": "cnn",
|
| 29 |
+
"clearml_model_id": "",
|
| 30 |
+
"clearml_task_id": "d86e57c96044410b8fd151c084b1d527",
|
| 31 |
+
"hf_repo_id": "MZaik/Plant_Disease_Detection",
|
| 32 |
+
"hf_filename": "plant_disease_cnn.pt",
|
| 33 |
+
"local_path": "checkpoints/best_cnn.pt",
|
| 34 |
+
},
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
def predict_batch(files, model_name: str, topk: int):
|
| 38 |
+
if not files:
|
| 39 |
+
return pd.DataFrame(columns=["Image", "Rank", "Disease", "Probability", "Model"])
|
| 40 |
+
|
| 41 |
+
cfg = MODEL_REGISTRY[model_name]
|
| 42 |
+
|
| 43 |
+
os.environ["MODEL_TYPE"] = cfg["model_type"]
|
| 44 |
+
os.environ["CLEARML_MODEL_ID"] = cfg.get("clearml_model_id", "")
|
| 45 |
+
os.environ["CLEARML_TASK_ID"] = cfg.get("clearml_task_id", "")
|
| 46 |
+
os.environ["HF_REPO_ID"] = cfg.get("hf_repo_id", "")
|
| 47 |
+
os.environ["HF_FILENAME"] = cfg.get("hf_filename", "")
|
| 48 |
+
os.environ["MODEL_PATH"] = cfg.get("local_path", "")
|
| 49 |
+
|
| 50 |
+
rows = []
|
| 51 |
+
for fp in files:
|
| 52 |
+
img = Image.open(fp).convert("RGB")
|
| 53 |
+
df = predict_image(img, k=topk)
|
| 54 |
+
df.insert(0, "Rank", range(1, len(df) + 1))
|
| 55 |
+
df.insert(0, "Image", os.path.basename(fp))
|
| 56 |
+
df["Model"] = model_name
|
| 57 |
+
rows.append(df)
|
| 58 |
+
|
| 59 |
+
return pd.concat(rows, ignore_index=True)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def add_gallery_to_files(evt: gr.SelectData, current_files):
|
| 63 |
+
# current_files will be list[str] of filepaths (or None)
|
| 64 |
+
if current_files is None:
|
| 65 |
+
current_files = []
|
| 66 |
+
selected_path = SAMPLE_IMGS[evt.index]
|
| 67 |
+
# avoid duplicates
|
| 68 |
+
if selected_path not in current_files:
|
| 69 |
+
current_files.append(selected_path)
|
| 70 |
+
return current_files
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def flag_current(files, model_name: str, topk: int, results_df: pd.DataFrame):
|
| 74 |
+
os.makedirs("flagged", exist_ok=True)
|
| 75 |
+
out_path = os.path.join("flagged", "flags.csv")
|
| 76 |
+
|
| 77 |
+
if not files or results_df is None or len(results_df) == 0:
|
| 78 |
+
return "Nothing to flag"
|
| 79 |
+
|
| 80 |
+
df = results_df.copy()
|
| 81 |
+
df.insert(0, "Files_Selected", ";".join(files))
|
| 82 |
+
df.insert(1, "Model_Selected", model_name)
|
| 83 |
+
df.insert(2, "TopK_Selected", int(topk))
|
| 84 |
+
|
| 85 |
+
header = not os.path.exists(out_path)
|
| 86 |
+
df.to_csv(out_path, mode="a", header=header, index=False)
|
| 87 |
+
|
| 88 |
+
return "Flag saved"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
with gr.Blocks(title="Plant Disease Detector") as demo:
|
| 92 |
+
gr.Markdown("# 🌿 Plant Disease Detector")
|
| 93 |
+
|
| 94 |
+
with gr.Row():
|
| 95 |
+
files = gr.File(
|
| 96 |
+
label="Upload one or more images",
|
| 97 |
+
file_count="multiple",
|
| 98 |
+
file_types=["image"],
|
| 99 |
+
type="filepath",
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
with gr.Row():
|
| 103 |
+
model_choice = gr.Dropdown(
|
| 104 |
+
label="Model",
|
| 105 |
+
choices=list(MODEL_REGISTRY.keys()),
|
| 106 |
+
value="ResNet18",
|
| 107 |
+
)
|
| 108 |
+
topk = gr.Slider(1, 5, value=5, step=1, label="Top-K predictions")
|
| 109 |
+
|
| 110 |
+
with gr.Row():
|
| 111 |
+
predict_btn = gr.Button("Predict", variant="primary")
|
| 112 |
+
flag_btn = gr.Button("Flag")
|
| 113 |
+
clear_btn = gr.Button("Clear")
|
| 114 |
+
|
| 115 |
+
with gr.Row():
|
| 116 |
+
gallery = gr.Gallery(
|
| 117 |
+
label="Plant Gallery",
|
| 118 |
+
value=SAMPLE_IMGS,
|
| 119 |
+
columns=5,
|
| 120 |
+
height=200,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
results = gr.DataFrame(
|
| 124 |
+
headers=["Image", "Rank", "Disease", "Probability", "Model"],
|
| 125 |
+
datatype=["str", "number", "str", "number", "str"],
|
| 126 |
+
interactive=False,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
flag_status = gr.Textbox(label="Flag status", interactive=False)
|
| 130 |
+
|
| 131 |
+
gallery.select(
|
| 132 |
+
fn=add_gallery_to_files,
|
| 133 |
+
inputs=[files],
|
| 134 |
+
outputs=[files],
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
predict_btn.click(predict_batch, inputs=[files, model_choice, topk], outputs=results)
|
| 138 |
+
|
| 139 |
+
flag_btn.click(
|
| 140 |
+
fn=flag_current,
|
| 141 |
+
inputs=[files, model_choice, topk, results],
|
| 142 |
+
outputs=[flag_status],
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
clear_btn.click(
|
| 146 |
+
lambda: ([], "ResNet18", 5, pd.DataFrame(columns=["Image","Rank","Disease","Probability","Model"]), ""),
|
| 147 |
+
inputs=None,
|
| 148 |
+
outputs=[files, model_choice, topk, results, flag_status],
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
demo.launch()
|
configs/train_cnn.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project: "PlantDisease"
|
| 2 |
+
task_name: "model_training"
|
| 3 |
+
|
| 4 |
+
# Model Selection: 'resnet18' or 'cnn'
|
| 5 |
+
model_type: "resnet18"
|
| 6 |
+
|
| 7 |
+
num_classes: 39
|
| 8 |
+
dropout: 0.5
|
| 9 |
+
|
| 10 |
+
lr: 0.001
|
| 11 |
+
weight_decay: 0.0001
|
| 12 |
+
epochs: 10
|
| 13 |
+
batch_size: 32
|
| 14 |
+
image_size: 256
|
| 15 |
+
|
| 16 |
+
patience: 3
|
| 17 |
+
min_delta: 0.001
|
| 18 |
+
save_last: true
|
| 19 |
+
|
| 20 |
+
data_path: "data/processed_plant_village"
|
| 21 |
+
train_samples_per_epoch : 38000
|
| 22 |
+
val_samples_per_epoch: 8000
|
| 23 |
+
test_samples_per_epoch: 8000
|
| 24 |
+
num_workers: 0
|
notebooks/data_prep_hristina.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/eda.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
process_dataset.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset, DatasetDict
|
| 2 |
+
|
| 3 |
+
SEED = 21
|
| 4 |
+
|
| 5 |
+
print("Downloading dataset from hugging face...")
|
| 6 |
+
|
| 7 |
+
# Download the dataset from huggingface
|
| 8 |
+
ds = load_dataset("DScomp380/plant_village")
|
| 9 |
+
# extract dataset from DatasetDict
|
| 10 |
+
ds = ds['train']
|
| 11 |
+
|
| 12 |
+
print("Splitting dataset into train/test/validation...")
|
| 13 |
+
|
| 14 |
+
# First extract the training set
|
| 15 |
+
temp = ds.train_test_split(train_size=0.70, shuffle=True, seed=SEED)
|
| 16 |
+
# then split remaining dataset for test/validation
|
| 17 |
+
test_valid_ds = temp['test'].train_test_split(train_size=0.5, shuffle=True, seed=SEED)
|
| 18 |
+
|
| 19 |
+
# assign the sub datasets
|
| 20 |
+
train_ds = temp['train']
|
| 21 |
+
validation_ds = test_valid_ds['train']
|
| 22 |
+
test_ds = test_valid_ds['test']
|
| 23 |
+
|
| 24 |
+
# combine into one DatasetDict
|
| 25 |
+
ds_dict = DatasetDict({
|
| 26 |
+
"train": train_ds,
|
| 27 |
+
"test": test_ds,
|
| 28 |
+
"validation": validation_ds
|
| 29 |
+
})
|
| 30 |
+
|
| 31 |
+
ds_dict.save_to_disk("data/processed_plant_village")
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy<2.0
|
| 2 |
+
torch
|
| 3 |
+
pillow
|
| 4 |
+
datasets
|
| 5 |
+
matplotlib
|
| 6 |
+
torchvision
|
| 7 |
+
scikit-learn
|
| 8 |
+
seaborn
|
| 9 |
+
clearml
|
| 10 |
+
gradio
|
| 11 |
+
huggingface_hub
|
src/DataLoader/__init__.py
ADDED
|
File without changes
|
src/DataLoader/dataloader.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import datasets
|
| 3 |
+
from torch.utils.data import DataLoader, WeightedRandomSampler, BatchSampler
|
| 4 |
+
|
| 5 |
+
from src.DataLoader.plantvillage_dataset import PlantVillageDataset
|
| 6 |
+
from src.DataLoader.utils import calc_class_dist
|
| 7 |
+
|
| 8 |
+
def create_dataloader(dataset: datasets.Dataset, batch_size: int, samples_per_epoch: int, is_training_set: bool = True) -> DataLoader:
|
| 9 |
+
"""
|
| 10 |
+
Creates a new torch dataloader using given dataset and parameters.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
dataset (datasets.Dataset): Dataset loaded using huggingface datasets library.
|
| 14 |
+
batch_size (int): Number of examples to sample per batch.
|
| 15 |
+
samples_per_epoch (int): Total number of examples to sample in one epoch.
|
| 16 |
+
is_training_set (bool): decides whether the given dataset should provide augmented images (Default is True).
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
torch.utils.data.Dataloader: Returns a newly created dataloader.
|
| 20 |
+
|
| 21 |
+
Example:
|
| 22 |
+
`loader = create_dataloader(train_ds, 32, 1000)`
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
# Assign creation of torch-compatible dataset from hf dataset
|
| 26 |
+
torch_dataset = PlantVillageDataset(dataset, is_training_set)
|
| 27 |
+
|
| 28 |
+
# retrieve class percentages
|
| 29 |
+
class_percent = torch.tensor(calc_class_dist(dataset))
|
| 30 |
+
|
| 31 |
+
# calculate weight per class
|
| 32 |
+
class_weights = 1.0 / class_percent
|
| 33 |
+
class_weights = class_weights / class_weights.sum()
|
| 34 |
+
|
| 35 |
+
# assign class weights to each sample for the weighted sampler
|
| 36 |
+
labels = torch.tensor(dataset['label'])
|
| 37 |
+
sample_weights = class_weights[labels]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# create sampler and dataloader
|
| 41 |
+
sampler = WeightedRandomSampler(sample_weights, replacement = True, num_samples=samples_per_epoch)
|
| 42 |
+
loader = DataLoader(torch_dataset, batch_sampler=BatchSampler(sampler, batch_size=batch_size, drop_last=True))
|
| 43 |
+
|
| 44 |
+
return loader
|
src/DataLoader/plantvillage_dataset.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
import torchvision.transforms.v2 as T
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
from src.DataLoader.utils import rotate90
|
| 7 |
+
|
| 8 |
+
class PlantVillageDataset(Dataset):
|
| 9 |
+
def __init__(self, dataset: datasets.Dataset, is_training_set: bool):
|
| 10 |
+
self.dataset = dataset
|
| 11 |
+
|
| 12 |
+
self.num_classes = len(set(dataset['label']))
|
| 13 |
+
|
| 14 |
+
norm_transform = T.Compose([
|
| 15 |
+
T.Resize((256,256)),
|
| 16 |
+
T.ToImage(),
|
| 17 |
+
T.ToDtype(torch.float32, scale=True) # change values to range 0.0-1.0
|
| 18 |
+
])
|
| 19 |
+
|
| 20 |
+
if is_training_set:
|
| 21 |
+
|
| 22 |
+
augmentation = T.RandomApply([
|
| 23 |
+
T.ColorJitter(brightness=0.5, saturation=0.4),
|
| 24 |
+
T.RandomHorizontalFlip(p=0.5),
|
| 25 |
+
T.Lambda(rotate90), # wrap custom rotation function
|
| 26 |
+
T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))
|
| 27 |
+
], p=0.5)
|
| 28 |
+
|
| 29 |
+
self.transform = T.Compose([
|
| 30 |
+
augmentation,
|
| 31 |
+
norm_transform
|
| 32 |
+
])
|
| 33 |
+
|
| 34 |
+
# no augmentation for test/validation sets
|
| 35 |
+
else:
|
| 36 |
+
self.transform = norm_transform
|
| 37 |
+
|
| 38 |
+
def __len__(self):
|
| 39 |
+
return len(self.dataset)
|
| 40 |
+
|
| 41 |
+
def __getitem__(self, idx):
|
| 42 |
+
|
| 43 |
+
img, label = self.dataset['image'][idx], self.dataset['label'][idx]
|
| 44 |
+
|
| 45 |
+
# augment/normalize image
|
| 46 |
+
img = self.transform(img)
|
| 47 |
+
|
| 48 |
+
label = torch.tensor(label)
|
| 49 |
+
|
| 50 |
+
return img, F.one_hot(label, num_classes=self.num_classes)
|
src/DataLoader/utils.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
import random
|
| 3 |
+
import torchvision.transforms.v2.functional as functional
|
| 4 |
+
from collections import Counter
|
| 5 |
+
|
| 6 |
+
def rotate90(image):
|
| 7 |
+
"""Rotate the image by a random multiple of 90 degrees"""
|
| 8 |
+
angle = 90 * random.randint(1,3)
|
| 9 |
+
return functional.rotate(image, angle=angle)
|
| 10 |
+
|
| 11 |
+
def calc_class_dist(dataset: datasets.Dataset) -> list[float]:
|
| 12 |
+
"""
|
| 13 |
+
Return percentage of total examples, done per class.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
# extract classes only
|
| 17 |
+
labels = dataset["label"]
|
| 18 |
+
counts = Counter(labels)
|
| 19 |
+
|
| 20 |
+
total_size = sum(counts.values())
|
| 21 |
+
percents = [100 * counts.get(i, 0) / total_size for i in range(max(labels)+1)]
|
| 22 |
+
|
| 23 |
+
return percents
|
| 24 |
+
|
| 25 |
+
def int_to_string(dataset: datasets.Dataset, int_label: int) -> str:
|
| 26 |
+
"""
|
| 27 |
+
Converts integer labels to their string counterpart.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
if not (0 <= int_label <= 38):
|
| 31 |
+
raise ValueError(f"Given label value, {int_label}, is out of range.")
|
| 32 |
+
|
| 33 |
+
return dataset.features['label'].int2str(int_label)
|
src/__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import subprocess
|
| 8 |
+
from sklearn.metrics import confusion_matrix, classification_report
|
| 9 |
+
import seaborn as sns
|
| 10 |
+
|
| 11 |
+
from src.DataLoader.dataloader import create_dataloader
|
| 12 |
+
from src.models.resnet18_finetune import make_resnet18
|
| 13 |
+
from src.models.cnn_model import PlantCNN
|
| 14 |
+
from src.utils.config import load_config
|
| 15 |
+
from datasets import load_from_disk
|
| 16 |
+
from clearml import Task, InputModel
|
| 17 |
+
|
| 18 |
+
def load_model_from_clearml(task_id: str, model_type: str, num_classes: int, device: str):
|
| 19 |
+
print(f"[INFO] Loading 'best_model' artifact from Task ID '{task_id}'...")
|
| 20 |
+
try:
|
| 21 |
+
source_task = Task.get_task(task_id=task_id)
|
| 22 |
+
model_path = source_task.artifacts['best_model'].get_local_copy()
|
| 23 |
+
print(f"[INFO] Model downloaded to: {model_path}")
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"[FATAL] Could not retrieve artifact from Task {task_id}. Error: {e}")
|
| 26 |
+
exit(1)
|
| 27 |
+
|
| 28 |
+
if model_type.lower() == 'resnet18':
|
| 29 |
+
model = make_resnet18(num_classes=num_classes)
|
| 30 |
+
elif model_type.lower() == 'cnn':
|
| 31 |
+
model = PlantCNN(num_classes=num_classes)
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
| 34 |
+
|
| 35 |
+
state_dict = torch.load(model_path, map_location=device)
|
| 36 |
+
if 'state_dict' in state_dict:
|
| 37 |
+
state_dict = state_dict['state_dict']
|
| 38 |
+
|
| 39 |
+
model.load_state_dict(state_dict)
|
| 40 |
+
model.to(device)
|
| 41 |
+
model.eval()
|
| 42 |
+
print("[SUCCESS] Model loaded and ready.")
|
| 43 |
+
return model
|
| 44 |
+
|
| 45 |
+
def evaluate_model(model, loader, device):
|
| 46 |
+
"""
|
| 47 |
+
Runs inference on the entire dataloader and returns predictions and labels.
|
| 48 |
+
"""
|
| 49 |
+
all_preds, all_labels = [], []
|
| 50 |
+
with torch.no_grad():
|
| 51 |
+
for inputs, labels in loader:
|
| 52 |
+
inputs = inputs.to(device)
|
| 53 |
+
if labels.ndim > 1:
|
| 54 |
+
labels = labels.argmax(dim=1)
|
| 55 |
+
|
| 56 |
+
outputs = model(inputs)
|
| 57 |
+
preds = outputs.argmax(dim=1).cpu().numpy()
|
| 58 |
+
|
| 59 |
+
all_preds.extend(preds)
|
| 60 |
+
all_labels.extend(labels.cpu().numpy())
|
| 61 |
+
return np.array(all_labels), np.array(all_preds)
|
| 62 |
+
|
| 63 |
+
def main():
    """CLI entry point: evaluate a trained model (fetched from ClearML) on the test split.

    Logs a classification report and a confusion-matrix figure back to ClearML.
    Requires --task_id and --model_type on the command line.
    """
    # Register this run in ClearML as a testing-type task.
    task = Task.init(project_name="PlantDisease", task_name="model_evaluation", task_type=Task.TaskTypes.testing)

    task.set_packages("./requirements.txt")

    # NOTE: when a 'default' queue worker exists, execution continues remotely from here.
    task.execute_remotely(queue_name="default")

    parser = argparse.ArgumentParser(description="Evaluate a trained model from ClearML.")
    parser.add_argument('--task_id', type=str, required=True, help="ClearML Task ID that produced the model.")
    parser.add_argument('--model_type', type=str, required=True, choices=['resnet18', 'cnn'])
    args = parser.parse_args()

    # Record the CLI arguments as task hyperparameters.
    task.connect(args)
    logger = task.get_logger()

    print(f"--- Evaluating Model from Task ID: {args.task_id} ({args.model_type.upper()}) ---")

    cfg = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Materialize the processed dataset on disk if it is missing.
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found. Running processing script...")
        subprocess.check_call([sys.executable, "process_dataset.py"])

    ds_dict = load_from_disk(data_path)
    # One full pass over the test split (samples_per_epoch == len(test)).
    test_loader = create_dataloader(ds_dict['test'], batch_size=32, samples_per_epoch=len(ds_dict['test']), is_training_set=False)

    class_names = ds_dict['test'].features['label'].names
    num_classes = len(class_names)

    model = load_model_from_clearml(args.task_id, args.model_type, num_classes, device)

    y_true, y_pred = evaluate_model(model, test_loader, device)

    print("\n--- Generating Reports and Plots ---")

    # Two calls: a dict for the artifact store and a text version for the console.
    report_dict = classification_report(y_true, y_pred, target_names=class_names, zero_division=0, output_dict=True)
    report_text = classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
    print(report_text)
    task.upload_artifact(name="classification_report", artifact_object=report_dict)

    # Confusion matrix: large figure because there are ~39 classes.
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(22, 22))
    sns.heatmap(cm, annot=False, cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.title(f'Confusion Matrix - {args.model_type.upper()}', fontsize=16)
    plt.tight_layout()

    logger.report_matplotlib_figure(title="Confusion Matrix", series=args.model_type, figure=plt, report_image=True)

    print("[SUCCESS] Evaluation complete. Artifacts logged to ClearML.")

    task.close()

if __name__ == "__main__":
    main()
|
src/inference.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torchvision.transforms.v2 as T
|
| 13 |
+
|
| 14 |
+
from clearml import InputModel, Task
|
| 15 |
+
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
|
| 18 |
+
from src.models.cnn_model import PlantCNN
|
| 19 |
+
from src.models.resnet18_finetune import make_resnet18
|
| 20 |
+
|
| 21 |
+
from src.utils.class_names import CLASS_NAMES
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
_MODEL_CACHE = {}
|
| 25 |
+
_CLASS_NAMES_CACHE = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _device_key(device: torch.device) -> str:
|
| 29 |
+
return str(device)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _get_device() -> torch.device:
|
| 33 |
+
if torch.cuda.is_available():
|
| 34 |
+
return torch.device("cuda")
|
| 35 |
+
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
| 36 |
+
return torch.device("mps")
|
| 37 |
+
return torch.device("cpu")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_val_transform(image_size: int = 256) -> T.Compose:
    """Inference preprocessing: square resize, then float32 tensor scaled to [0, 1]."""
    steps = [
        T.Resize((image_size, image_size)),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ]
    return T.Compose(steps)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _load_model_from_checkpoint(
    model_path: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load a model from a local checkpoint file.

    Accepts three checkpoint layouts: a dict wrapping the weights under a
    'state_dict' key, a plain state-dict, or a fully pickled nn.Module.

    Raises:
        FileNotFoundError: if *model_path* does not exist.
        ValueError: for an unknown *model_type* or an unreadable checkpoint.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found - {model_path}")

    ckpt = torch.load(model_path, map_location=device)

    kind = model_type.lower()
    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif kind == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        model.load_state_dict(ckpt["state_dict"])
    elif isinstance(ckpt, nn.Module):
        # The checkpoint is a fully pickled model; use it as-is.
        model = ckpt
    else:
        try:
            model.load_state_dict(ckpt)
        except Exception as exc:
            # Chain the underlying error so the root cause is not lost (was swallowed before).
            raise ValueError(f"Unexpected checkpoint format in - {model_path}. ") from exc

    model.to(device)
    model.eval()

    return model
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _load_model_from_clearml_model_id(
    model_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download a model from the ClearML model registry by model ID and load it.

    Raises:
        FileNotFoundError: if the download fails or no weight file can be located.
        ValueError: for an unknown *model_type*.
    """
    model_obj = InputModel(model_id=model_id)
    downloaded_path = model_obj.get_local_copy()

    if downloaded_path is None:
        raise FileNotFoundError(f"Failed to download model from ClearML Model ID - {model_id}")

    # BUG FIX: model_path must be initialised; previously it was unbound when the
    # directory contained no .pt/.pth file and none of the fallback names matched,
    # producing a NameError instead of the intended FileNotFoundError.
    model_path = None
    if os.path.isdir(downloaded_path):
        model_files = [f for f in os.listdir(downloaded_path) if f.endswith((".pt", ".pth"))]
        if model_files:
            model_path = os.path.join(downloaded_path, model_files[0])
        else:
            for name in ["best_baseline.pt", "best_model.pt", "best_baseline.pth", "best_model.pth"]:
                candidate = os.path.join(downloaded_path, name)
                if os.path.isfile(candidate):
                    model_path = candidate
                    break
        if model_path is None:
            raise FileNotFoundError(f"No model file found in directory - {downloaded_path}")
    else:
        model_path = downloaded_path

    if model_type.lower() == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif model_type.lower() == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    state_dict = torch.load(model_path, map_location=device)
    # Unwrap nested checkpoints saved as {'state_dict': ...}.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _load_model_from_clearml_task_id(
    task_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Retrieve a trained model artifact from a ClearML task and load it onto *device*."""
    producer = Task.get_task(task_id=task_id)

    # Probe the well-known artifact names in order of preference.
    model_path = None
    for candidate in ("best_model", "best_baseline", "model"):
        if candidate in producer.artifacts:
            model_path = producer.artifacts[candidate].get_local_copy()
            if model_path:
                break

    if model_path is None:
        raise FileNotFoundError(f"No model artifact found in Task ID - {task_id}")

    kind = model_type.lower()
    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif kind == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    weights = torch.load(model_path, map_location=device)
    # Unwrap nested checkpoints saved as {'state_dict': ...}.
    if isinstance(weights, dict) and "state_dict" in weights:
        weights = weights["state_dict"]

    model.load_state_dict(weights)
    model.to(device)
    model.eval()

    return model
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _load_model_from_huggingface(
    repo_id: str,
    filename: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download *filename* from Hugging Face Hub repo *repo_id* and load it as a model."""
    weights_file = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")

    kind = model_type.lower()
    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif kind == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    weights = torch.load(weights_file, map_location=device)
    # Unwrap nested checkpoints saved as {'state_dict': ...}.
    if isinstance(weights, dict) and "state_dict" in weights:
        weights = weights["state_dict"]

    model.load_state_dict(weights)
    model.to(device)
    model.eval()

    return model
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _get_class_names() -> List[str]:
    """Expose the static PlantVillage label list used for decoding predictions."""
    names = CLASS_NAMES
    return names
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def predict_image(img: Image.Image, k: int = 5) -> pd.DataFrame:
    """
    Predict the top-k disease classes for a single PIL image.

    The model is loaded lazily from the first source that succeeds, in priority
    order: ClearML model ID, ClearML task ID, Hugging Face Hub, local checkpoint.
    Each source is configured via environment variables (MODEL_TYPE, MODEL_PATH,
    CLEARML_MODEL_ID, CLEARML_TASK_ID, HF_REPO_ID, HF_FILENAME) and memoised in
    _MODEL_CACHE so repeated calls reuse the loaded model.

    Returns a DataFrame with columns: Disease, Probability. On any failure a
    single-row DataFrame carrying the error message is returned instead of raising.
    """
    if img is None:
        # Nothing to classify - return an empty result frame.
        return pd.DataFrame({"Disease": [], "Probability": []})

    try:
        class_names = _get_class_names()
        if not class_names:
            raise ValueError("class_names list is empty.")

        # All source configuration comes from environment variables.
        model_type = os.getenv("MODEL_TYPE", "resnet18")
        model_path = os.getenv("MODEL_PATH", "")
        clearml_model_id = os.getenv("CLEARML_MODEL_ID", "")
        clearml_task_id = os.getenv("CLEARML_TASK_ID", "")
        hf_repo_id = os.getenv("HF_REPO_ID", "")
        hf_filename = os.getenv("HF_FILENAME", "")

        device = _get_device()
        device_k = _device_key(device)
        num_classes = len(class_names)
        transform = _build_val_transform(image_size=256)

        # Preprocess to a (1, C, H, W) float batch on the target device.
        x = transform(img.convert("RGB")).unsqueeze(0).to(device)

        model = None

        # ClearML Model ID (highest priority). A failed load is cached as None
        # so the expensive download is not retried on every call.
        if clearml_model_id and clearml_model_id.strip():
            cache_key = ("clearml_model", model_type, clearml_model_id, num_classes, device_k)
            if cache_key not in _MODEL_CACHE:
                try:
                    _MODEL_CACHE[cache_key] = _load_model_from_clearml_model_id(clearml_model_id, num_classes, model_type, device)
                except Exception:
                    _MODEL_CACHE[cache_key] = None
            model = _MODEL_CACHE.get(cache_key)

        # ClearML Task ID (second priority), same negative-caching scheme.
        if model is None and clearml_task_id and clearml_task_id.strip():
            cache_key = ("clearml_task", model_type, clearml_task_id, num_classes, device_k)
            if cache_key not in _MODEL_CACHE:
                try:
                    _MODEL_CACHE[cache_key] = _load_model_from_clearml_task_id(clearml_task_id, num_classes, model_type, device)
                except Exception:
                    _MODEL_CACHE[cache_key] = None
            model = _MODEL_CACHE.get(cache_key)

        # Hugging Face Hub (third priority); requires both repo id and filename.
        if model is None and hf_repo_id and hf_repo_id.strip() and hf_filename and hf_filename.strip():
            cache_key = ("huggingface", model_type, hf_repo_id, hf_filename, num_classes, device_k)
            if cache_key not in _MODEL_CACHE:
                try:
                    _MODEL_CACHE[cache_key] = _load_model_from_huggingface(hf_repo_id, hf_filename, num_classes, model_type, device)
                except Exception:
                    _MODEL_CACHE[cache_key] = None
            model = _MODEL_CACHE.get(cache_key)

        # Local checkpoint (last resort); here a load failure propagates
        # so it is surfaced in the error DataFrame below.
        if model is None:
            if model_path and os.path.isfile(model_path):
                cache_key = ("local", model_type, model_path, num_classes, device_k)
                if cache_key not in _MODEL_CACHE:
                    _MODEL_CACHE[cache_key] = _load_model_from_checkpoint(model_path, num_classes, model_type, device)
                model = _MODEL_CACHE[cache_key]
            else:
                raise FileNotFoundError(
                    f"All loading methods failed. Model ID - {clearml_model_id}, Task ID - {clearml_task_id}, HF - {hf_repo_id}/{hf_filename}, Local path - {model_path}"
                )


        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=1)[0]

        # Never request more classes than exist.
        topk = min(int(k), len(class_names))
        top_probs, top_indices = torch.topk(probs, k=topk)

        results = [
            (class_names[idx.item()], float(prob.item()))
            for prob, idx in zip(top_probs, top_indices)
        ]

        return pd.DataFrame({
            "Disease": [r[0] for r in results],
            "Probability": [r[1] for r in results],
        })

    except Exception as e:
        # UI-facing function: report errors in-band rather than raising.
        return pd.DataFrame({"Disease": [f"Error: {str(e)}"], "Probability": [0.0]})
|
src/models/__pycache__/cnn_model.cpython-312.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
src/models/__pycache__/resnet18_finetune.cpython-312.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
src/models/cnn_model.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
def conv_block(cin: int, cout: int, kernel_size: int = 3, padding: int = 1, p_drop: float = 0.1) -> nn.Sequential:
    """Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Dropout2d.

    Each block doubles as a downsampling stage (MaxPool halves H and W) and a
    regularisation stage (spatial dropout on the feature maps).
    """
    layers = [
        nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=False),
        nn.BatchNorm2d(cout),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
        nn.Dropout2d(p_drop),
    ]
    return nn.Sequential(*layers)

class PlantCNN(nn.Module):
    """Compact baseline CNN for the PlantVillage dataset.

    Design choices:
      1. Channel depth grows 3 -> 32 -> 64 -> 128 while spatial size halves per block.
      2. Global average pooling yields a fixed-length feature vector for any input size.
      3. A small two-layer dense head performs the final classification.
    """
    def __init__(self, num_classes: int = 39, p_drop: float = 0.5):
        super().__init__()

        # Feature extractor: three conv blocks of increasing width.
        self.features = nn.Sequential(
            conv_block(3, 32),
            conv_block(32, 64),
            conv_block(64, 128),
        )

        # Global average pool gives a (B, 128, 1, 1) tensor regardless of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier head: 128 -> 64 -> num_classes with dropout in between.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map an image batch (B, 3, H, W) to class logits (B, num_classes)."""
        feats = self.features(x)
        pooled = self.avgpool(feats)
        return self.classifier(pooled)
|
src/models/resnet18_finetune.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torchvision import models
|
| 3 |
+
|
| 4 |
+
def make_resnet18(num_classes=39):
    """Build a ResNet18 transfer-learning model for PlantVillage.

    The ImageNet-pretrained backbone is frozen; only the newly attached
    fully-connected head has trainable parameters.
    """
    # Same pretrained weights variant as used in the lab (IMAGENET1K_V1).
    net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Freeze every pretrained weight so only the new head learns.
    for p in net.parameters():
        p.requires_grad = False

    # Swap the 1000-way ImageNet head for a num_classes-way linear layer
    # (a fresh nn.Linear is trainable by default).
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    return net
|
src/train/early_stopping.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class EarlyStopping:
    """Stop training when a monitored metric stops improving.

    ``step`` returns True once the metric has failed to beat the best seen
    value (by more than ``min_delta``) for ``patience`` consecutive calls.
    """
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = float(min_delta)
        self.best = None   # best metric value observed so far
        self.count = 0     # consecutive non-improving steps

    def step(self, metric):
        """Record one epoch's metric value; return True when training should stop."""
        improved = self.best is None or metric > self.best + self.min_delta
        if improved:
            self.best = metric
            self.count = 0
            return False
        self.count += 1
        return self.count >= self.patience
|
src/train/train.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from datasets import load_from_disk
|
| 7 |
+
import subprocess
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
# Import models
|
| 11 |
+
from src.models.resnet18_finetune import make_resnet18
|
| 12 |
+
from src.models.cnn_model import PlantCNN
|
| 13 |
+
|
| 14 |
+
# Import utils
|
| 15 |
+
from src.utils.config import load_config
|
| 16 |
+
from src.utils.metrics import accuracy, topk_accuracy
|
| 17 |
+
from src.train.early_stopping import EarlyStopping
|
| 18 |
+
|
| 19 |
+
# Import Dataloader
|
| 20 |
+
from src.DataLoader.dataloader import create_dataloader
|
| 21 |
+
|
| 22 |
+
def train_one_epoch(model, loader, criterion, opt, device):
    """Run one optimisation pass over *loader*; return (mean_loss, accuracy)."""
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)

        # One-hot targets are converted to indices for CrossEntropyLoss.
        if batch_y.ndim > 1:
            batch_y = batch_y.argmax(dim=1)
        batch_y = batch_y.to(device).long()

        opt.zero_grad(set_to_none=True)
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        opt.step()

        n = batch_x.size(0)
        running_loss += loss.item() * n
        n_correct += (logits.argmax(1) == batch_y).sum().item()
        n_seen += n
    return running_loss / n_seen, n_correct / n_seen
|
| 44 |
+
|
| 45 |
+
@torch.no_grad()
def evaluate(model, loader, criterion, device, topk=5):
    """Evaluate *model* on *loader*; return (mean_loss, top-1 accuracy, top-k accuracy)."""
    model.eval()
    running_loss = 0.0
    n_top1 = 0
    n_topk = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        # One-hot targets are converted to indices for CrossEntropyLoss.
        if batch_y.ndim > 1:
            batch_y = batch_y.argmax(dim=1)
        batch_y = batch_y.to(device).long()

        logits = model(batch_x)
        loss = criterion(logits, batch_y)

        n = batch_x.size(0)
        running_loss += loss.item() * n
        n_top1 += (logits.argmax(1) == batch_y).sum().item()

        # A sample counts for top-k when the true label is among the k highest logits.
        candidates = logits.topk(topk, dim=1).indices
        n_topk += (candidates == batch_y.unsqueeze(1)).any(dim=1).sum().item()
        n_seen += n
    return running_loss / n_seen, n_top1 / n_seen, n_topk / n_seen
|
| 67 |
+
|
| 68 |
+
def main():
    """End-to-end training entry point.

    Loads config, initialises optional ClearML tracking, prepares the dataset
    (running the processing script if needed), builds the chosen architecture,
    trains with early stopping, checkpoints the best model by validation
    accuracy, and uploads it as a ClearML artifact when tracking is enabled.
    """
    print("[INFO] Starting Integration Training Pipeline")

    # 1. Config
    cfg = load_config()
    os.makedirs("checkpoints", exist_ok=True)

    # 2. ClearML (optional: degrades gracefully when the package is absent)
    try:
        from clearml import Task
        task = Task.init(project_name=cfg.get("project", "PlantDisease"), task_name=cfg.get("task_name", "model_training"))
        task.set_packages("./requirements.txt")
        # NOTE: when a 'default' queue worker exists, execution continues remotely from here.
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        logger = None
        print("[INFO] ClearML not found, skipping logging")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")

    # Materialize the processed dataset on disk if it is missing.
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            exit(1)

    # 3. Data
    print(f"[INFO] Loading data from {cfg['data_path']}")
    ds_dict = load_from_disk(cfg['data_path'])

    dl_train = create_dataloader(ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    # NOTE(review): dl_test is built but never used in this function - confirm
    # whether a final test evaluation was intended here.
    dl_test = create_dataloader(ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)

    # 4. Model Selection & Optimizer Setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")

    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # For ResNet transfer learning, we typically only optimize the head
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for ResNet head only.")

    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'], p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # For custom CNN, we optimize all parameters
        opt = AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")

    else:
        raise ValueError(f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")

    # 5. Setup Loss & Stopper
    crit = nn.CrossEntropyLoss()
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])

    # 6. Loop: train, validate, checkpoint on improvement, stop early on plateau.
    best_acc = 0.0
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)

        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}")

        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)

        # Keep only the best-by-validation-accuracy weights on disk.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "checkpoints/best_baseline.pt")

        if stopper.step(val_acc):
            print("Early stopping.")
            break

    # 'task' exists whenever 'logger' is truthy (both set in the try above).
    if logger:
        print("[INFO] Uploading best model artifact to ClearML...")
        task.upload_artifact(name="best_model", artifact_object="checkpoints/best_baseline.pt")
        print("[SUCCESS] Model uploaded.")

if __name__ == "__main__":
    main()
|
src/utils/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (572 Bytes). View file
|
|
|
src/utils/class_names.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Canonical PlantVillage label list (39 classes, including a background class).
# Order matters: model output index i corresponds to CLASS_NAMES[i], so this
# list must match the label encoding used when the models were trained.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Background_without_leaves',
    'Blueberry___healthy',
    'Cherry___Powdery_mildew',
    'Cherry___healthy',
    'Corn___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn___Common_rust',
    'Corn___Northern_Leaf_Blight',
    'Corn___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy'
]
|
src/utils/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
|
| 3 |
+
def load_config(config_path: str = "configs/train_cnn.yaml") -> dict:
    """Read *config_path* and return its parsed YAML contents as a dict."""
    with open(config_path) as stream:
        cfg = yaml.safe_load(stream)
    return cfg
|
src/utils/metrics.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def accuracy(logits, y):
    """Fraction of rows in *logits* whose argmax equals the target index in *y*."""
    hits = logits.argmax(1) == y
    return hits.float().mean().item()
|
| 5 |
+
|
| 6 |
+
def topk_accuracy(logits, y, k=5):
    """Fraction of samples whose true label appears among the k highest logits."""
    candidates = logits.topk(k, dim=1).indices
    hits = (candidates == y.unsqueeze(1)).any(dim=1)
    return hits.float().mean().item()
|
| 9 |
+
|