JAMM032 committed on
Commit
97fcc90
·
verified ·
1 Parent(s): 730decc

Upload github repo files

Browse files

Sync up spaces repo with main github repo

.gitignore ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
3
+
4
+ ### Python ###
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
165
+
166
+ ### Python Patch ###
167
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
168
+ poetry.toml
169
+
170
+ # ruff
171
+ .ruff_cache/
172
+
173
+ # LSP config files
174
+ pyrightconfig.json
175
+
176
+ # End of https://www.toptal.com/developers/gitignore/api/python
177
+
178
+ # Data specific files/folders
179
+ data/
180
+
181
+ # Mac's stuff
182
+ .DS_Store
183
+
184
+ # Model
185
+ checkpoints/
186
+
187
+ # UI
188
+ flagged/
README.md CHANGED
@@ -1,13 +1,103 @@
1
- ---
2
- title: Testing
3
- emoji: 🌍
4
- colorFrom: pink
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 6.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 5CCSAGAP: AI/Robotics Group Project - Plant Disease Detection
2
+
3
+ This repository contains the source code for the "Plant Disease Detector" application, developed for the 5CCSAGAP module.
4
+
5
+ ## 1. Project Overview
6
+
7
+ The goal of this project is to build an AI system capable of detecting plant diseases from leaf images. The application is built with a Python backend using PyTorch and is deployed with a Gradio user interface. The system supports two models: a custom-built CNN and a high-performance model based on ResNet18 transfer learning.
8
+
9
+ **Team 4 Members:**
10
+ - Oguzhan Cagirir
11
+ - Janit Bhardwaj
12
+ - Hristina Georgieva
13
+ - Hissan Omar
14
+ - Kasim Morsel
15
+ - Mark Soltyk
16
+
17
+ ## 2. Running the Application
18
+
19
+ This guide is for running the Gradio application locally to demonstrate its functionality.
20
+
21
+ ### Prerequisites
22
+ - Python 3.10+
23
+ - An active internet connection (for downloading model weights on first run).
24
+ - Git LFS is **not** required as large files are hosted externally.
25
+
26
+ ### Setup Instructions
27
+
28
+ 1. **Clone the repository:**
29
+ ```bash
30
+ git clone https://github.kcl.ac.uk/k23136072/Small-group-project.git
31
+ cd Small-group-project
32
+ ```
33
+
34
+ 2. **Create and activate a Python virtual environment:**
35
+
36
+ **On Windows:**
37
+ ```bash
38
+ python -m venv .venv
39
+ .venv\Scripts\activate
40
+ ```
41
+
42
+ **On Linux/macOS:**
43
+ ```bash
44
+ python3 -m venv .venv
45
+ source .venv/bin/activate
46
+ ```
47
+
48
+ 3. **Install dependencies:**
49
+ ```bash
50
+ pip install -r requirements.txt
51
+ ```
52
+
53
+ ### Launching the Gradio UI
54
+
55
+ Once the setup is complete, launch the application:
56
+
57
+ ```bash
58
+ python app.py
59
+ ```
60
+
61
+ The application will start and provide a local URL (e.g., `http://127.0.0.1:7860`). Open this URL in your web browser.
62
+
63
+ **Note on First Run:** The first time you make a prediction with a specific model (e.g., "ResNet18"), the application will automatically download the required model weights (~45MB) from Hugging Face Hub. This may take a moment. Subsequent runs will use the cached local copy.
64
+
65
+ ## 3. Repository Structure
66
+
67
+ - `app.py`: The main Gradio application entry point.
68
+ - `src/`: Contains all core source code.
69
+ - `src/inference.py`: The core API for running model predictions. Handles model downloading and preprocessing.
70
+ - `src/models/`: Contains the PyTorch model definitions (`cnn_model.py`, `resnet18_finetune.py`).
71
+ - `src/DataLoader/`: Contains the data loading and processing logic used for training.
72
+ - `src/train/`: Contains the training script (`train.py`); the evaluation script lives at `src/evaluate.py`.
73
+ - `src/utils/`: Configuration and helper utilities.
74
+ - `configs/`: Contains the YAML configuration file for training.
75
+ - `requirements.txt`: Project dependencies.
76
+ - `process_dataset.py`: Utility script to download and prepare the dataset from Hugging Face.
77
+ - `README.md`: This file.
78
+
79
+ ## 4. Model Information
80
+
81
+ The application supports two models, selectable via the UI dropdown:
82
+
83
+ 1. **ResNet18 (Default):** A high-performance model using transfer learning from a pretrained ResNet18. Achieves **~96% accuracy** on the test set.
84
+ 2. **CNN:** A custom-built baseline CNN, trained from scratch. Achieves **~78% accuracy** on the test set.
85
+
86
+ Model weights are hosted on Hugging Face Hub at `MZaik/Plant_Disease_Detection` and are downloaded automatically by the application.
87
+
88
+ ## 5. Training and Evaluation (Advanced)
89
+
90
+ To reproduce the training or evaluation runs, the repository is integrated with ClearML for experiment tracking.
91
+
92
+ **Training:**
93
+ 1. Configure `configs/train_cnn.yaml` (set `model_type` to `resnet18` or `cnn`).
94
+ 2. Run `python -m src.train.train`.
95
+ * This will automatically trigger data processing if the dataset is missing.
96
+ * It will execute on the configured ClearML server if credentials are present, or locally if not.
97
+
98
+ **Evaluation:**
99
+ 1. Obtain a Model Task ID from a successful training run on ClearML.
100
+ 2. Run `python -m src.evaluate --model_type resnet18 --task_id [TASK_ID]`.
101
+ * This generates the confusion matrix and classification report artifacts.
102
+
103
+ ---
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+ from src.inference import predict_image
7
+
8
# Bundled sample leaf images shown in the gallery; clicking one adds it to the upload list.
SAMPLE_IMGS = [
    "plant_gallery/img1.jpg",
    "plant_gallery/img2.jpg",
    "plant_gallery/img3.jpg",
    "plant_gallery/img4.jpg",
    "plant_gallery/img5.jpg",
]

# Per-model configuration consumed by src.inference (passed via environment variables
# in predict_batch). Keys are the display names offered in the UI dropdown.
MODEL_REGISTRY = {

    "ResNet18": {
        "model_type": "resnet18",
        "clearml_model_id": "",  # Model ID can be left blank since task artifacts are in use within the KCL system.
        "clearml_task_id": "85a9b1c01454493087b5c068a4f21ec6",
        "hf_repo_id": "MZaik/Plant_Disease_Detection",
        "hf_filename": "plant_disease_resnet18.pt",
        "local_path": "checkpoints/best_resnet18.pt",
    },

    "CNN": {
        "model_type": "cnn",
        "clearml_model_id": "",
        "clearml_task_id": "d86e57c96044410b8fd151c084b1d527",
        "hf_repo_id": "MZaik/Plant_Disease_Detection",
        "hf_filename": "plant_disease_cnn.pt",
        "local_path": "checkpoints/best_cnn.pt",
    },
}
36
+
37
def predict_batch(files, model_name: str, topk: int):
    """Run top-k inference on each selected file with the chosen model.

    Returns a DataFrame with columns Image, Rank, Disease, Probability, Model;
    an empty frame with those columns when no files are selected.
    """
    if not files:
        return pd.DataFrame(columns=["Image", "Rank", "Disease", "Probability", "Model"])

    cfg = MODEL_REGISTRY[model_name]

    # src.inference reads its model configuration from the environment.
    # NOTE(review): process-wide env mutation assumes one request at a time — confirm.
    os.environ.update({
        "MODEL_TYPE": cfg["model_type"],
        "CLEARML_MODEL_ID": cfg.get("clearml_model_id", ""),
        "CLEARML_TASK_ID": cfg.get("clearml_task_id", ""),
        "HF_REPO_ID": cfg.get("hf_repo_id", ""),
        "HF_FILENAME": cfg.get("hf_filename", ""),
        "MODEL_PATH": cfg.get("local_path", ""),
    })

    per_file_frames = []
    for file_path in files:
        image = Image.open(file_path).convert("RGB")
        frame = predict_image(image, k=topk)
        frame.insert(0, "Rank", range(1, len(frame) + 1))
        frame.insert(0, "Image", os.path.basename(file_path))
        frame["Model"] = model_name
        per_file_frames.append(frame)

    return pd.concat(per_file_frames, ignore_index=True)
60
+
61
+
62
def add_gallery_to_files(evt: gr.SelectData, current_files):
    """Append the clicked gallery sample to the upload file list (no duplicates)."""
    # Gradio passes the file list as list[str] of paths, or None before any upload.
    if current_files is None:
        current_files = []

    picked = SAMPLE_IMGS[evt.index]
    # Only add the sample once, even if it is clicked repeatedly.
    if picked not in current_files:
        current_files.append(picked)

    return current_files
71
+
72
+
73
def flag_current(files, model_name: str, topk: int, results_df: pd.DataFrame):
    """Append the current prediction table (plus UI selections) to flagged/flags.csv."""
    os.makedirs("flagged", exist_ok=True)
    out_path = os.path.join("flagged", "flags.csv")

    # Nothing useful to persist without both inputs and a non-empty results table.
    if not files or results_df is None or len(results_df) == 0:
        return "Nothing to flag"

    flagged = results_df.copy()
    flagged.insert(0, "Files_Selected", ";".join(files))
    flagged.insert(1, "Model_Selected", model_name)
    flagged.insert(2, "TopK_Selected", int(topk))

    # Emit the CSV header only when the file is first created; afterwards append rows.
    write_header = not os.path.exists(out_path)
    flagged.to_csv(out_path, mode="a", header=write_header, index=False)

    return "Flag saved"
89
+
90
+
91
# Gradio UI wiring: upload/select images, pick a model, predict, flag, clear.
with gr.Blocks(title="Plant Disease Detector") as demo:
    gr.Markdown("# 🌿 Plant Disease Detector")

    # Multi-file upload; "filepath" keeps values as plain path strings.
    with gr.Row():
        files = gr.File(
            label="Upload one or more images",
            file_count="multiple",
            file_types=["image"],
            type="filepath",
        )

    # Model selection and how many top predictions to show per image.
    with gr.Row():
        model_choice = gr.Dropdown(
            label="Model",
            choices=list(MODEL_REGISTRY.keys()),
            value="ResNet18",
        )
        topk = gr.Slider(1, 5, value=5, step=1, label="Top-K predictions")

    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        flag_btn = gr.Button("Flag")
        clear_btn = gr.Button("Clear")

    # Clickable sample images; selecting one adds it to the upload list.
    with gr.Row():
        gallery = gr.Gallery(
            label="Plant Gallery",
            value=SAMPLE_IMGS,
            columns=5,
            height=200,
        )

    # Prediction results table (read-only).
    results = gr.DataFrame(
        headers=["Image", "Rank", "Disease", "Probability", "Model"],
        datatype=["str", "number", "str", "number", "str"],
        interactive=False,
    )

    flag_status = gr.Textbox(label="Flag status", interactive=False)

    # Clicking a gallery thumbnail appends its path to the file input.
    gallery.select(
        fn=add_gallery_to_files,
        inputs=[files],
        outputs=[files],
    )

    predict_btn.click(predict_batch, inputs=[files, model_choice, topk], outputs=results)

    # Persist the current results + selections to flagged/flags.csv.
    flag_btn.click(
        fn=flag_current,
        inputs=[files, model_choice, topk, results],
        outputs=[flag_status],
    )

    # Reset every control back to its initial state.
    clear_btn.click(
        lambda: ([], "ResNet18", 5, pd.DataFrame(columns=["Image","Rank","Disease","Probability","Model"]), ""),
        inputs=None,
        outputs=[files, model_choice, topk, results, flag_status],
    )

demo.launch()
configs/train_cnn.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project: "PlantDisease"
2
+ task_name: "model_training"
3
+
4
+ # Model Selection: 'resnet18' or 'cnn'
5
+ model_type: "resnet18"
6
+
7
+ num_classes: 39
8
+ dropout: 0.5
9
+
10
+ lr: 0.001
11
+ weight_decay: 0.0001
12
+ epochs: 10
13
+ batch_size: 32
14
+ image_size: 256
15
+
16
+ patience: 3
17
+ min_delta: 0.001
18
+ save_last: true
19
+
20
+ data_path: "data/processed_plant_village"
21
+ train_samples_per_epoch: 38000
22
+ val_samples_per_epoch: 8000
23
+ test_samples_per_epoch: 8000
24
+ num_workers: 0
notebooks/data_prep_hristina.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/eda.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
process_dataset.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from datasets import load_dataset, DatasetDict

SEED = 21

print("Downloading dataset from hugging face...")

# Pull the raw dataset from the Hugging Face Hub; all examples live in its 'train' split.
full_ds = load_dataset("DScomp380/plant_village")["train"]

print("Splitting dataset into train/test/validation...")

# 70% goes to training; the remaining 30% is split in half (seeded for reproducibility).
first_split = full_ds.train_test_split(train_size=0.70, shuffle=True, seed=SEED)
holdout_split = first_split["test"].train_test_split(train_size=0.5, shuffle=True, seed=SEED)

# Reassemble into a single DatasetDict with the conventional split names.
ds_dict = DatasetDict({
    "train": first_split["train"],
    "test": holdout_split["test"],
    "validation": holdout_split["train"],
})

ds_dict.save_to_disk("data/processed_plant_village")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy<2.0
2
+ torch
3
+ pillow
4
+ datasets
5
+ matplotlib
6
+ torchvision
7
+ scikit-learn
8
+ seaborn
9
+ clearml
10
+ gradio
11
+ huggingface_hub
src/DataLoader/__init__.py ADDED
File without changes
src/DataLoader/dataloader.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import datasets
3
+ from torch.utils.data import DataLoader, WeightedRandomSampler, BatchSampler
4
+
5
+ from src.DataLoader.plantvillage_dataset import PlantVillageDataset
6
+ from src.DataLoader.utils import calc_class_dist
7
+
8
def create_dataloader(dataset: datasets.Dataset, batch_size: int, samples_per_epoch: int, is_training_set: bool = True) -> DataLoader:
    """
    Build a class-balanced torch DataLoader over a Hugging Face dataset.

    Args:
        dataset (datasets.Dataset): Dataset loaded with the huggingface datasets library.
        batch_size (int): Number of examples sampled per batch.
        samples_per_epoch (int): Total number of examples drawn in one epoch.
        is_training_set (bool): When True, the wrapped dataset applies augmentation (default True).

    Returns:
        torch.utils.data.DataLoader: The configured dataloader.

    Example:
        `loader = create_dataloader(train_ds, 32, 1000)`
    """

    # Wrap the HF dataset so torch can index it (augmentation handled inside the wrapper).
    wrapped = PlantVillageDataset(dataset, is_training_set)

    # Per-class share of the data, inverted so under-represented classes get larger weights.
    per_class_percent = torch.tensor(calc_class_dist(dataset))
    inverse_weights = 1.0 / per_class_percent
    inverse_weights = inverse_weights / inverse_weights.sum()

    # Expand the per-class weights to one weight per individual sample.
    label_tensor = torch.tensor(dataset['label'])
    per_sample_weights = inverse_weights[label_tensor]

    # Sample with replacement so every epoch sees a class-balanced mix.
    balanced_sampler = WeightedRandomSampler(per_sample_weights, replacement=True, num_samples=samples_per_epoch)
    return DataLoader(wrapped, batch_sampler=BatchSampler(balanced_sampler, batch_size=batch_size, drop_last=True))
src/DataLoader/plantvillage_dataset.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import torchvision.transforms.v2 as T
3
+ import torch.nn.functional as F
4
+ import datasets
5
+ import torch
6
+ from src.DataLoader.utils import rotate90
7
+
8
class PlantVillageDataset(Dataset):
    """Torch-compatible wrapper around the Hugging Face PlantVillage dataset.

    Applies augmentation for training sets and plain resize/normalisation for
    validation/test sets; yields (image_tensor, one_hot_label) pairs.
    """

    def __init__(self, dataset: datasets.Dataset, is_training_set: bool):
        self.dataset = dataset

        # Number of distinct labels present, used to size the one-hot vectors.
        self.num_classes = len(set(dataset['label']))

        norm_transform = T.Compose([
            T.Resize((256,256)),
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True) # change values to range 0.0-1.0
        ])

        if is_training_set:

            augmentation = T.RandomApply([
                T.ColorJitter(brightness=0.5, saturation=0.4),
                T.RandomHorizontalFlip(p=0.5),
                T.Lambda(rotate90), # wrap custom rotation function
                T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))
            ], p=0.5)

            self.transform = T.Compose([
                augmentation,
                norm_transform
            ])

        # no augmentation for test/validation sets
        else:
            self.transform = norm_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):

        # Index the row first: the previous `self.dataset['image'][idx]` form
        # materialises the ENTIRE column on every access, which is extremely
        # slow for Hugging Face datasets. Row access decodes only one example.
        sample = self.dataset[idx]
        img, label = sample['image'], sample['label']

        # augment/normalize image
        img = self.transform(img)

        label = torch.tensor(label)

        return img, F.one_hot(label, num_classes=self.num_classes)
src/DataLoader/utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import random
3
+ import torchvision.transforms.v2.functional as functional
4
+ from collections import Counter
5
+
6
def rotate90(image):
    """Rotate the image by a random quarter turn (90, 180 or 270 degrees)."""
    quarter_turns = random.randint(1, 3)
    return functional.rotate(image, angle=90 * quarter_turns)
10
+
11
def calc_class_dist(dataset: datasets.Dataset) -> list[float]:
    """
    Return percentage of total examples, done per class.

    Args:
        dataset (datasets.Dataset): Dataset exposing an integer "label" column.

    Returns:
        list[float]: Index i holds the percentage of examples with label i,
        for every label from 0 up to the largest label present. Empty list
        for an empty dataset.
    """

    # extract classes only
    labels = dataset["label"]
    counts = Counter(labels)

    total_size = sum(counts.values())
    # Guard: an empty dataset has no classes (max() below would raise ValueError).
    if total_size == 0:
        return []

    percents = [100 * counts.get(i, 0) / total_size for i in range(max(labels)+1)]

    return percents
24
+
25
def int_to_string(dataset: datasets.Dataset, int_label: int) -> str:
    """
    Converts integer labels to their string counterpart.

    Args:
        dataset (datasets.Dataset): Dataset whose "label" feature is a ClassLabel.
        int_label (int): Integer class id to translate.

    Returns:
        str: Human-readable class name.

    Raises:
        ValueError: If `int_label` is outside the feature's class range.
    """

    label_feature = dataset.features['label']
    # Derive the valid range from the ClassLabel feature instead of hard-coding
    # 0..38, so the helper also works for datasets with a different class count.
    # The fallback of 39 preserves the previous behaviour for plain features.
    num_classes = getattr(label_feature, 'num_classes', 39)
    if not (0 <= int_label < num_classes):
        raise ValueError(f"Given label value, {int_label}, is out of range.")

    return label_feature.int2str(int_label)
src/__pycache__/inference.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
src/evaluate.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import subprocess
8
+ from sklearn.metrics import confusion_matrix, classification_report
9
+ import seaborn as sns
10
+
11
+ from src.DataLoader.dataloader import create_dataloader
12
+ from src.models.resnet18_finetune import make_resnet18
13
+ from src.models.cnn_model import PlantCNN
14
+ from src.utils.config import load_config
15
+ from datasets import load_from_disk
16
+ from clearml import Task, InputModel
17
+
18
def load_model_from_clearml(task_id: str, model_type: str, num_classes: int, device: str):
    """
    Download the 'best_model' artifact from a ClearML task and load it for inference.

    Args:
        task_id (str): ClearML task id that produced the checkpoint.
        model_type (str): 'resnet18' or 'cnn'.
        num_classes (int): Number of output classes for the model head.
        device (str): Torch device string to load the weights onto.

    Returns:
        The model in eval() mode on `device`.

    Raises:
        SystemExit: If the artifact cannot be retrieved from ClearML.
        ValueError: If `model_type` is not recognised.
    """
    print(f"[INFO] Loading 'best_model' artifact from Task ID '{task_id}'...")
    try:
        source_task = Task.get_task(task_id=task_id)
        model_path = source_task.artifacts['best_model'].get_local_copy()
        print(f"[INFO] Model downloaded to: {model_path}")
    except Exception as e:
        print(f"[FATAL] Could not retrieve artifact from Task {task_id}. Error: {e}")
        # sys.exit instead of the `exit` site builtin: `exit` targets the interactive
        # interpreter and can be absent (python -S, frozen apps); sys is already imported.
        sys.exit(1)

    if model_type.lower() == 'resnet18':
        model = make_resnet18(num_classes=num_classes)
    elif model_type.lower() == 'cnn':
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    state_dict = torch.load(model_path, map_location=device)
    # Training may save a wrapper dict; unwrap to the raw state_dict when present.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("[SUCCESS] Model loaded and ready.")
    return model
44
+
45
def evaluate_model(model, loader, device):
    """
    Run inference over the entire dataloader.

    Returns:
        tuple: (true labels, predicted labels) as 1-D numpy int arrays.
    """
    collected_labels = []
    collected_preds = []
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to(device)
            # Collapse one-hot targets down to class indices.
            if batch_labels.ndim > 1:
                batch_labels = batch_labels.argmax(dim=1)

            logits = model(batch_inputs)
            batch_preds = logits.argmax(dim=1).cpu().numpy()

            collected_preds.extend(batch_preds)
            collected_labels.extend(batch_labels.cpu().numpy())
    return np.array(collected_labels), np.array(collected_preds)
62
+
63
def main():
    """Entry point: evaluate a trained checkpoint from ClearML on the test split.

    Registers itself as a ClearML testing task, loads the model produced by the
    given --task_id, runs inference on the held-out test set, and uploads the
    classification report and confusion matrix as ClearML artifacts.
    """
    task = Task.init(project_name="PlantDisease", task_name="model_evaluation", task_type=Task.TaskTypes.testing)

    # Pin the remote environment to the repo's dependency list.
    task.set_packages("./requirements.txt")

    # Hands execution off to a ClearML agent when a server is configured.
    task.execute_remotely(queue_name="default")

    parser = argparse.ArgumentParser(description="Evaluate a trained model from ClearML.")
    parser.add_argument('--task_id', type=str, required=True, help="ClearML Task ID that produced the model.")
    parser.add_argument('--model_type', type=str, required=True, choices=['resnet18', 'cnn'])
    args = parser.parse_args()

    # Record CLI arguments on the ClearML task for reproducibility.
    task.connect(args)
    logger = task.get_logger()

    print(f"--- Evaluating Model from Task ID: {args.task_id} ({args.model_type.upper()}) ---")

    cfg = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the processed dataset on demand if it is missing locally.
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found. Running processing script...")
        subprocess.check_call([sys.executable, "process_dataset.py"])

    ds_dict = load_from_disk(data_path)
    # Evaluate on the full test split, without augmentation.
    test_loader = create_dataloader(ds_dict['test'], batch_size=32, samples_per_epoch=len(ds_dict['test']), is_training_set=False)

    class_names = ds_dict['test'].features['label'].names
    num_classes = len(class_names)

    model = load_model_from_clearml(args.task_id, args.model_type, num_classes, device)

    y_true, y_pred = evaluate_model(model, test_loader, device)

    print("\n--- Generating Reports and Plots ---")

    # Report twice: dict form for the artifact, text form for the console.
    report_dict = classification_report(y_true, y_pred, target_names=class_names, zero_division=0, output_dict=True)
    report_text = classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
    print(report_text)
    task.upload_artifact(name="classification_report", artifact_object=report_dict)

    # Large figure so all 39 class labels remain legible.
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(22, 22))
    sns.heatmap(cm, annot=False, cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.title(f'Confusion Matrix - {args.model_type.upper()}', fontsize=16)
    plt.tight_layout()

    logger.report_matplotlib_figure(title="Confusion Matrix", series=args.model_type, figure=plt, report_image=True)

    print("[SUCCESS] Evaluation complete. Artifacts logged to ClearML.")

    task.close()

if __name__ == "__main__":
    main()
src/inference.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from PIL import Image
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torchvision.transforms.v2 as T
13
+
14
+ from clearml import InputModel, Task
15
+
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ from src.models.cnn_model import PlantCNN
19
+ from src.models.resnet18_finetune import make_resnet18
20
+
21
+ from src.utils.class_names import CLASS_NAMES
22
+
23
+
24
+ _MODEL_CACHE = {}
25
+ _CLASS_NAMES_CACHE = None
26
+
27
+
28
+ def _device_key(device: torch.device) -> str:
29
+ return str(device)
30
+
31
+
32
+ def _get_device() -> torch.device:
33
+ if torch.cuda.is_available():
34
+ return torch.device("cuda")
35
+ if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
36
+ return torch.device("mps")
37
+ return torch.device("cpu")
38
+
39
+
40
def _build_val_transform(image_size: int = 256) -> T.Compose:
    """Inference-time preprocessing: resize, tensor conversion, scale to [0, 1]."""
    steps = [
        T.Resize((image_size, image_size)),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0.0, 1.0]
    ]
    return T.Compose(steps)
46
+
47
+
48
def _load_model_from_checkpoint(
    model_path: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load a model from a local checkpoint file.

    Accepts a raw state_dict, a {"state_dict": ...} wrapper, or a fully
    pickled nn.Module.

    Raises:
        FileNotFoundError: If `model_path` does not exist.
        ValueError: If `model_type` is unknown or the checkpoint format is unexpected.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found - {model_path}")

    ckpt = torch.load(model_path, map_location=device)

    if model_type.lower() == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif model_type.lower() == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        model.load_state_dict(ckpt["state_dict"])
    elif isinstance(ckpt, nn.Module):
        model = ckpt
    else:
        try:
            model.load_state_dict(ckpt)
        except Exception as exc:
            # Chain the original error so the real load failure stays visible.
            raise ValueError(f"Unexpected checkpoint format in - {model_path}. ") from exc

    model.to(device)
    model.eval()

    return model
80
+
81
+
82
def _load_model_from_clearml_model_id(
    model_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load a checkpoint registered in the ClearML model registry by model id.

    Raises:
        FileNotFoundError: If the download fails or no checkpoint file can be located.
        ValueError: If `model_type` is unknown.
    """

    model_obj = InputModel(model_id=model_id)
    downloaded_path = model_obj.get_local_copy()

    if downloaded_path is None:
        raise FileNotFoundError(f"Failed to download model from ClearML Model ID - {model_id}")

    # BUGFIX: initialise before the search. Previously, when the downloaded
    # directory contained no .pt/.pth file and none of the named candidates
    # existed, the `model_path is None` check raised UnboundLocalError
    # instead of the intended FileNotFoundError.
    model_path = None
    if os.path.isdir(downloaded_path):
        model_files = [f for f in os.listdir(downloaded_path) if f.endswith((".pt", ".pth"))]
        if model_files:
            model_path = os.path.join(downloaded_path, model_files[0])
        else:
            for name in ["best_baseline.pt", "best_model.pt", "best_baseline.pth", "best_model.pth"]:
                candidate = os.path.join(downloaded_path, name)
                if os.path.isfile(candidate):
                    model_path = candidate
                    break
            if model_path is None:
                raise FileNotFoundError(f"No model file found in directory - {downloaded_path}")
    else:
        model_path = downloaded_path

    if model_type.lower() == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif model_type.lower() == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    state_dict = torch.load(model_path, map_location=device)
    # Unwrap the training-time {"state_dict": ...} wrapper when present.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model
126
+
127
+
128
def _load_model_from_clearml_task_id(
    task_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load a checkpoint stored as an artifact on a ClearML training task."""

    source_task = Task.get_task(task_id=task_id)

    # Probe the artifact names used across our training runs, preferred first.
    model_path = None
    for candidate_name in ("best_model", "best_baseline", "model"):
        if candidate_name in source_task.artifacts:
            model_path = source_task.artifacts[candidate_name].get_local_copy()
            if model_path:
                break

    if model_path is None:
        raise FileNotFoundError(f"No model artifact found in Task ID - {task_id}")

    kind = model_type.lower()
    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif kind == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    checkpoint = torch.load(model_path, map_location=device)
    # Unwrap the training-time {"state_dict": ...} wrapper when present.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    return model
164
+
165
+
166
def _load_model_from_huggingface(
    repo_id: str,
    filename: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download a checkpoint file from the Hugging Face Hub and load it.

    Returns the requested architecture in eval mode on `device`.

    Raises:
        ValueError: If `model_type` is not 'resnet18' or 'cnn'.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")

    kind = model_type.lower()
    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif kind == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    # Checkpoints may be a raw state dict or a wrapper dict with a "state_dict" key.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    return model
192
+
193
+
194
+
195
def _get_class_names() -> List[str]:
    """Return the ordered list of class labels the model predicts over."""
    return CLASS_NAMES
197
+
198
+
199
def _load_cached_model(cache_key, loader):
    """Memoised model load for one source.

    Failures are cached as None so that a broken or unreachable remote source
    is not re-downloaded on every prediction; the caller then falls through to
    the next loading strategy. NOTE(review): a transient network failure is
    therefore never retried until the process restarts.
    """
    if cache_key not in _MODEL_CACHE:
        try:
            _MODEL_CACHE[cache_key] = loader()
        except Exception:
            # Deliberate best-effort: swallow the error and record the failure.
            _MODEL_CACHE[cache_key] = None
    return _MODEL_CACHE.get(cache_key)


def predict_image(img: Image.Image, k: int = 5) -> pd.DataFrame:
    """
    Predict top-k disease classes for a single PIL image.

    Model resolution order (first source that succeeds wins): ClearML model ID,
    ClearML task ID, Hugging Face Hub, then a local checkpoint path — all
    configured through environment variables (MODEL_TYPE, MODEL_PATH,
    CLEARML_MODEL_ID, CLEARML_TASK_ID, HF_REPO_ID, HF_FILENAME).

    Returns a DataFrame with columns: Disease, Probability. On any failure a
    single-row DataFrame with an "Error: ..." message is returned instead of
    raising.
    """
    if img is None:
        return pd.DataFrame({"Disease": [], "Probability": []})

    try:
        class_names = _get_class_names()
        if not class_names:
            raise ValueError("class_names list is empty.")

        model_type = os.getenv("MODEL_TYPE", "resnet18")
        model_path = os.getenv("MODEL_PATH", "")
        clearml_model_id = os.getenv("CLEARML_MODEL_ID", "")
        clearml_task_id = os.getenv("CLEARML_TASK_ID", "")
        hf_repo_id = os.getenv("HF_REPO_ID", "")
        hf_filename = os.getenv("HF_FILENAME", "")

        device = _get_device()
        device_k = _device_key(device)
        num_classes = len(class_names)
        transform = _build_val_transform(image_size=256)

        # Batch of one; the transform is assumed to return a CHW tensor.
        x = transform(img.convert("RGB")).unsqueeze(0).to(device)

        model = None

        # 1. ClearML Model ID
        if clearml_model_id and clearml_model_id.strip():
            model = _load_cached_model(
                ("clearml_model", model_type, clearml_model_id, num_classes, device_k),
                lambda: _load_model_from_clearml_model_id(clearml_model_id, num_classes, model_type, device),
            )

        # 2. ClearML Task ID
        if model is None and clearml_task_id and clearml_task_id.strip():
            model = _load_cached_model(
                ("clearml_task", model_type, clearml_task_id, num_classes, device_k),
                lambda: _load_model_from_clearml_task_id(clearml_task_id, num_classes, model_type, device),
            )

        # 3. Hugging Face Hub (needs both repo and filename)
        if model is None and hf_repo_id and hf_repo_id.strip() and hf_filename and hf_filename.strip():
            model = _load_cached_model(
                ("huggingface", model_type, hf_repo_id, hf_filename, num_classes, device_k),
                lambda: _load_model_from_huggingface(hf_repo_id, hf_filename, num_classes, model_type, device),
            )

        # 4. Local checkpoint — NOT wrapped in the best-effort helper: a corrupt
        # local file should surface as an explicit error, not a silent fallback.
        if model is None:
            if model_path and os.path.isfile(model_path):
                cache_key = ("local", model_type, model_path, num_classes, device_k)
                if cache_key not in _MODEL_CACHE:
                    _MODEL_CACHE[cache_key] = _load_model_from_checkpoint(model_path, num_classes, model_type, device)
                model = _MODEL_CACHE[cache_key]
            else:
                raise FileNotFoundError(
                    f"All loading methods failed. Model ID - {clearml_model_id}, Task ID - {clearml_task_id}, HF - {hf_repo_id}/{hf_filename}, Local path - {model_path}"
                )

        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=1)[0]

        topk = min(int(k), len(class_names))
        top_probs, top_indices = torch.topk(probs, k=topk)

        return pd.DataFrame({
            "Disease": [class_names[idx.item()] for idx in top_indices],
            "Probability": [float(prob.item()) for prob in top_probs],
        })

    except Exception as e:
        return pd.DataFrame({"Disease": [f"Error: {str(e)}"], "Probability": [0.0]})
src/models/__pycache__/cnn_model.cpython-312.pyc ADDED
Binary file (3.15 kB). View file
 
src/models/__pycache__/resnet18_finetune.cpython-312.pyc ADDED
Binary file (1.01 kB). View file
 
src/models/cnn_model.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
def conv_block(cin: int, cout: int, kernel_size: int = 3, padding: int = 1, p_drop: float = 0.1) -> nn.Sequential:
    """
    Build one Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Dropout2d stage.

    Downsamples spatial resolution by 2 (the pool) and applies spatial dropout
    for regularisation inside the feature extractor.
    """
    stage = [
        # bias is redundant when immediately followed by BatchNorm.
        nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=False),
        nn.BatchNorm2d(cout),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
        nn.Dropout2d(p_drop),
    ]
    return nn.Sequential(*stage)
14
+ )
15
+
16
class PlantCNN(nn.Module):
    """
    Simple CNN baseline for the PlantVillage dataset.

    Architecture: three conv stages that triple in channel depth
    (3 -> 32 -> 64 -> 128) while each halving the spatial resolution,
    a global average pool to a fixed-size feature vector, and a
    two-layer dense classification head with dropout.
    """

    def __init__(self, num_classes: int = 39, p_drop: float = 0.5):
        super().__init__()

        def stage(cin: int, cout: int) -> nn.Sequential:
            # Conv -> BN -> ReLU -> pool -> spatial dropout; halves H and W.
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
                nn.Dropout2d(0.1),
            )

        # Feature extractor: progressively deeper channels, smaller maps.
        self.features = nn.Sequential(
            stage(3, 32),
            stage(32, 64),
            stage(64, 128),
        )

        # Global average pool yields a fixed-size vector regardless of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head: 128 -> 64 -> num_classes with dropout.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Extract features, pool to 1x1, then classify."""
        return self.classifier(self.avgpool(self.features(x)))
src/models/resnet18_finetune.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torchvision import models
3
+
4
def make_resnet18(num_classes=39):
    """
    Build a ResNet18 for PlantVillage via transfer learning.

    Loads ImageNet (V1) pretrained weights, freezes every pretrained
    parameter, and replaces the final fully-connected layer with a fresh
    trainable head of `num_classes` outputs.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Freeze the whole backbone; only the new head below will train.
    for p in net.parameters():
        p.requires_grad = False

    # The replacement Linear is created unfrozen (requires_grad=True).
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    return net
src/train/early_stopping.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class EarlyStopping:
    """Stop training once a monitored metric fails to improve for `patience` steps.

    A step counts as an improvement when the metric exceeds the best seen so
    far by more than `min_delta`.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = float(min_delta)
        self.best = None        # best metric observed so far (None until first step)
        self.count = 0          # consecutive non-improving steps

    def step(self, metric):
        """Record one metric observation; return True when training should stop."""
        improved = self.best is None or metric > self.best + self.min_delta
        if improved:
            self.best = metric
            self.count = 0
            return False
        self.count += 1
        return self.count >= self.patience
src/train/train.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.optim import AdamW
6
+ from datasets import load_from_disk
7
+ import subprocess
8
+ import sys
9
+
10
+ # Import models
11
+ from src.models.resnet18_finetune import make_resnet18
12
+ from src.models.cnn_model import PlantCNN
13
+
14
+ # Import utils
15
+ from src.utils.config import load_config
16
+ from src.utils.metrics import accuracy, topk_accuracy
17
+ from src.train.early_stopping import EarlyStopping
18
+
19
+ # Import Dataloader
20
+ from src.DataLoader.dataloader import create_dataloader
21
+
22
def train_one_epoch(model, loader, criterion, opt, device):
    """Run one optimisation pass over `loader`; return (mean_loss, accuracy)."""
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)

        # One-hot targets are collapsed to class indices for CrossEntropyLoss.
        if batch_y.ndim > 1:
            batch_y = batch_y.argmax(dim=1)
        batch_y = batch_y.to(device).long()

        opt.zero_grad(set_to_none=True)
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        opt.step()

        n = batch_x.size(0)
        running_loss += loss.item() * n
        n_correct += (logits.argmax(1) == batch_y).sum().item()
        n_seen += n
    return running_loss / n_seen, n_correct / n_seen
44
+
45
@torch.no_grad()
def evaluate(model, loader, criterion, device, topk=5):
    """Evaluate on `loader`; return (mean_loss, top1_accuracy, topk_accuracy)."""
    model.eval()
    loss_sum = 0.0
    hits_top1 = 0
    hits_topk = 0
    seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        # One-hot targets are collapsed to class indices for CrossEntropyLoss.
        if batch_y.ndim > 1:
            batch_y = batch_y.argmax(dim=1)
        batch_y = batch_y.to(device).long()

        logits = model(batch_x)
        n = batch_x.size(0)
        loss_sum += criterion(logits, batch_y).item() * n
        hits_top1 += (logits.argmax(1) == batch_y).sum().item()

        # A sample scores for top-k if its label is among the k largest logits.
        topk_preds = logits.topk(topk, dim=1).indices
        hits_topk += (topk_preds == batch_y.unsqueeze(1)).any(dim=1).sum().item()
        seen += n
    return loss_sum / seen, hits_top1 / seen, hits_topk / seen
67
+
68
def main():
    """End-to-end training entry point.

    Loads config, optionally initialises ClearML tracking, ensures the
    processed dataset exists (running the processing script if not), builds
    the selected model and optimizer, runs the train/validate loop with
    early stopping, checkpoints the best model, and uploads it as a ClearML
    artifact when tracking is available.
    """
    print("[INFO] Starting Integration Training Pipeline")

    # 1. Config and checkpoint directory
    cfg = load_config()
    os.makedirs("checkpoints", exist_ok=True)

    # 2. ClearML (optional — the pipeline runs without it)
    task = None
    logger = None
    try:
        from clearml import Task
        task = Task.init(project_name=cfg.get("project", "PlantDisease"), task_name=cfg.get("task_name", "model_training"))
        task.set_packages("./requirements.txt")
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        print("[INFO] ClearML not found, skipping logging")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")

    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            # sys.exit, not the bare `exit` builtin — `exit` comes from the
            # site module and is not guaranteed in non-interactive runs.
            sys.exit(1)

    # 3. Data
    print(f"[INFO] Loading data from {cfg['data_path']}")
    ds_dict = load_from_disk(cfg['data_path'])

    dl_train = create_dataloader(ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    dl_test = create_dataloader(ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)

    # 4. Model selection & optimizer setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")

    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # Transfer learning: the backbone is frozen, so only optimize the head.
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for ResNet head only.")
    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'], p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # Custom CNN trains end-to-end.
        opt = AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")
    else:
        raise ValueError(f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")

    # 5. Loss & early stopping
    crit = nn.CrossEntropyLoss()
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])

    # 6. Train/validate loop
    best_acc = 0.0
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)

        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}")

        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "checkpoints/best_baseline.pt")

        if stopper.step(val_acc):
            print("Early stopping.")
            break

    if logger:
        ckpt_path = "checkpoints/best_baseline.pt"
        # Guard: no checkpoint exists if validation accuracy never exceeded 0.
        if os.path.isfile(ckpt_path):
            print("[INFO] Uploading best model artifact to ClearML...")
            task.upload_artifact(name="best_model", artifact_object=ckpt_path)
            print("[SUCCESS] Model uploaded.")
        else:
            print("[WARN] No checkpoint was saved; skipping artifact upload.")


if __name__ == "__main__":
    main()
src/utils/__pycache__/config.cpython-312.pyc ADDED
Binary file (572 Bytes). View file
 
src/utils/class_names.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Ordered class labels for the PlantVillage dataset (39 entries).
# The position of each entry is the label index the model emits, so this
# list must never be reordered — presumably it matches the label encoding
# used at training time (TODO confirm against the dataset processing step).
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Background_without_leaves',
    'Blueberry___healthy',
    'Cherry___Powdery_mildew',
    'Cherry___healthy',
    'Corn___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn___Common_rust',
    'Corn___Northern_Leaf_Blight',
    'Corn___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy'
]
src/utils/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import yaml
2
+
3
def load_config(config_path: str = "configs/train_cnn.yaml") -> dict:
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML file (defaults to the CNN training config).

    Returns:
        The parsed configuration (whatever yaml.safe_load yields — normally a dict).

    Raises:
        FileNotFoundError: If `config_path` does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding avoids locale-dependent decoding of the config file.
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
src/utils/metrics.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def accuracy(logits, y):
    """Top-1 accuracy: fraction of rows whose argmax matches the label."""
    preds = logits.argmax(dim=1)
    return (preds == y).float().mean().item()
5
+
6
def topk_accuracy(logits, y, k=5):
    """Top-k accuracy: fraction of rows whose label is among the k largest logits."""
    top_indices = logits.topk(k, dim=1).indices
    hits = (top_indices == y.unsqueeze(1)).any(dim=1)
    return hits.float().mean().item()
9
+