unionpoint committed on
Commit 5d2fa0b · verified · 1 Parent(s): d09450b

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,92 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # Installer logs
30
+ pip-log.txt
31
+ pip-delete-this-directory.txt
32
+
33
+ # Unit test / coverage reports
34
+ htmlcov/
35
+ .tox/
36
+ .nox/
37
+ .coverage
38
+ .coverage.*
39
+ .cache
40
+ nosetests.xml
41
+ coverage.xml
42
+ *.cover
43
+ *.pyo
44
+ .hypothesis/
45
+ .pytest_cache/
46
+ cover/
47
+
48
+ # Translation files
49
+ *.mo
50
+ *.pot
51
+
52
+ # Logs
53
+ *.log
54
+ logs/
55
+
56
+ # Django stuff (common patterns):
57
+ local_settings.py
58
+ db.sqlite3
59
+ db.sqlite3-journal
60
+
61
+ # Environments
62
+ .env
63
+ .venv
64
+ env/
65
+ venv/
66
+ ENV/
67
+ env.bak/
68
+ venv.bak/
69
+
70
+ # Project Specific Directories (Data & Models)
71
+ .chromadb/
72
+ .chromadb_test/
73
+ whisper-small-hy-ct2/
74
+ data/raw/
75
+ data/processed/
76
+ models/
77
+ task.pdf
78
+ plan.md
79
+
80
+ # Tool Caches
81
+ .ruff_cache/
82
+ .mypy_cache/
83
+
84
+ # IDEs
85
+ .vscode/
86
+ .idea/
87
+ .DS_Store
88
+ .history/
89
+
90
+ # OS-specific
91
+ Thumbs.db
92
+ Desktop.ini
README.md ADDED
@@ -0,0 +1,138 @@
1
+ # Plant Disease Classification
2
+
3
+ A robust, configurable deep learning pipeline for plant disease classification using PyTorch. This project leverages `timm` for a vast array of pre-trained backbones (e.g., EfficientNetV2, ConvNeXtV2, EVA02) and offers advanced training features such as Exponential Moving Average (EMA) for weights, Layer-wise Learning Rate Decay (LLRD), MixUp/CutMix data augmentation, and Weights & Biases (W&B) integration for experiment tracking.
4
+
5
+
6
+ - **Web Interface:** [](https://huggingface.co/spaces/)
7
+ - **REST API Documentation:** [s]()
8
+ ## Features
9
+
10
+ - **Extensive Model Support**: Easily swap backbones by changing the config, enabled by integration with `timm`.
11
+ - **Advanced Training Techniques**:
12
+   - Model EMA (Exponential Moving Average) to stabilize training and improve generalization.
13
+   - Layer-wise Learning Rate Decay (LLRD) for fine-tuning transformer and CNN architectures such as `vit` and `convnextv2`.
14
+   - Mixed Precision Training for faster execution and lower memory footprint.
15
+   - Gradient Accumulation.
16
+ - **Data Augmentation**: MixUp and CutMix integrations for regularization.
17
+ - **Customizable Configuration**: Highly modular experiment setups using `omegaconf` (YAML config files).
18
+ - **Experiment Tracking**: Full integration with Weights & Biases logging everything from hyperparameter configs to validation metrics.
19
+
20
+ ## Results
21
+
22
+ | Model | mAP | Accuracy |
23
+ | :--- | :---: | :---: |
24
+ | EfficientNetV2 Small | 0.87 | 0.815 |
25
+ | DINOv3 ViT Small Plus | 0.91 | 0.830 |
26
+ | ConvNeXtV2 Tiny | 0.94 | 0.860 |
27
+
28
+ ## Project Structure
29
+
30
+ ```
31
+ Plant-Disease-Classification/
32
+ ├── configs/
33
+ │ └── config.yaml # Main configuration file
34
+ ├── data/
35
+ │ ├── train/ # Train data (organized by class folders)
36
+ │ └── val/ # Val data (organized by class folders)
37
+ ├── src/
38
+ │ ├── dataset.py # Dataloaders and augmentation logic
39
+ │ ├── infer.py # Inference script and prediction utilities
40
+ │ ├── loss.py # Loss functions (CrossEntropy, Focal Loss)
41
+ │ ├── metrics.py # Metric calculations
42
+ │ ├── models.py # Model definitions and param groupings
43
+ │ ├── trainer.py # Core training loop
44
+ │ └── utils.py # Helpers (schedulers, seeds, config loading)
45
+ ├── train.py # Main entrypoint for training
46
+ └── requirements.txt # Project dependencies
47
+ ```
48
+
49
+ ## Quick Start
50
+
51
+ ### 1. Environment Setup
52
+
53
+ It is highly recommended to use [`uv`](https://github.com/astral-sh/uv) for fast, reliable package management.
54
+
55
+ ```bash
56
+ # Create a virtual environment using uv
57
+ uv venv
58
+
59
+ # Activate the environment
60
+ source .venv/bin/activate # Linux/MacOS
61
+
62
+ # Install dependencies rapidly
63
+ uv pip install -r requirements.txt
64
+ ```
65
+
66
+ ### 2. Prepare Data
67
+
68
+ Arrange the dataset in the torchvision `ImageFolder` layout, with training data in `data/train` and validation data in `data/val`. Each image belongs in the folder for its class; `src/dataset.py` parses folder names of the form `<plant> <disease>` (for example, `Apple scab`).
69
+
70
+ ```text
71
+ data/
72
+ └── train/
73
+ ├── Apple scab/
74
+ └── ...
75
+ ```
76
+
77
+ ### 3. Provide Configuration
78
+
79
+ Modify the hyperparameters, model choices, and paths inside `configs/config.yaml`.
80
+
81
+
82
+ ### 4. Train the Model
83
+
84
+ Run the training pipeline:
85
+
86
+ ```bash
87
+ python train.py --config configs/config.yaml
88
+ ```
89
+
90
+ **Resuming Training**:
91
+ To resume from an existing checkpoint, pass the `--resume` argument:
92
+ ```bash
93
+ python train.py --config configs/config.yaml --resume checkpoints/checkpoint.pth
94
+ ```
95
+
96
+ To load weights for a warm start (e.g., finetuning), use:
97
+ ```bash
98
+ python train.py --config configs/config.yaml --init_weights weights/pretrained.pth
99
+ ```
100
+
101
+ ### 5. Inference
102
+
103
+ You can run inference on a single image using the `src/infer.py` script. The script requires a serialized TorchScript model checkpoint.
104
+
105
+ ```bash
106
+ # Basic inference
107
+ python src/infer.py --image_path path/to/leaf.jpg --checkpoint checkpoints/best_model.pt --image_size 384
108
+
109
+ # Inference with Test Time Augmentation (TTA)
110
+ python src/infer.py --image_path path/to/leaf.jpg --checkpoint checkpoints/best_model.pt --image_size 384 --tta
111
+ ```
112
+
113
+ > **Note**: The inference script expects a `data/label_map.json` file to map class indices to disease names.
114
+
115
+ ## Documentation
116
+
117
+ ### Model Selection
118
+ By default, the pipeline uses `timm.create_model(...)`. You can specify any model architecture available in `timm` (e.g. `convnextv2_base`, `efficientnet_b0`, `eva02_base_patch14_448`) directly in the `config.yaml` file under `model.backbone`.
119
+
120
+ ### Configuration Details
121
+ The pipeline uses `OmegaConf`. Hyperparameters such as `loss`, `optimizer`, and `augmentation` can be tweaked. For example, to enable layer-wise learning rate decay, adjust `optimizer.layer_decay` to a value `< 1.0`.
122
+
123
+ ### Logging & Checkpoints
124
+ - Checkpoints are saved under the `checkpoints/` directory (customizable via `logging.checkpoint_dir`).
125
+ - Best model checkpoints (current and EMA) are tracked based on the monitored validation metric.
126
+ - When `logging.use_wandb` is true, the script initializes a Weights & Biases run, logging train/validation losses and selected metrics seamlessly.
127
+
128
+ ## Model Weights
129
+ ---
130
+
131
+ The trained weights are hosted on Hugging Face:
132
+ - 🔗 **[Download from Hugging Face Space Files](https://huggingface.co/spaces/)**
133
+
134
+
135
+ ## Technical Report
136
+ A comprehensive report of the results is included in the repository.
137
+
138
+ **[View Technical Report (PDF)]()**
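Note on layer-wise learning rate decay: the README enables LLRD through `optimizer.layer_decay`, but `src/models.py` is not visible in this part of the diff. The following is a minimal sketch of the usual scheme, assuming the head keeps `head_lr` and each layer group further from the head is scaled by one more factor of `layer_decay`. The function name `llrd_learning_rates` and the grouping into `num_layers` blocks are hypothetical, not taken from the repository.

```python
def llrd_learning_rates(head_lr, layer_decay, num_layers):
    """Return one learning rate per layer group, ordered from input to head.

    Assumption: group `num_layers` is the classification head and group 0 is
    the block closest to the input.
    """
    if not 0.0 < layer_decay <= 1.0:
        raise ValueError("layer_decay must be in (0, 1]")
    return [head_lr * layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


# Values borrowed from configs/CNeXv2t.yaml; the group count of 4 is illustrative.
rates = llrd_learning_rates(head_lr=3e-4, layer_decay=0.9, num_layers=4)
```

Under this sketch, `layer_decay: 1.0` gives every group the same rate; the repository presumably falls back to `backbone_lr` and `head_lr` in that case.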
api/main.py ADDED
@@ -0,0 +1,69 @@
+ import io
+ import json
+
+ import torch
+ from fastapi import FastAPI, File, HTTPException, UploadFile
+ from fastapi.responses import RedirectResponse
+ from PIL import Image
+
+ from src.infer import predict_disease
+
+ # Initialize FastAPI with metadata for Swagger
+ app = FastAPI(
+     title="Plant Disease API",
+     description="An API to identify plant diseases from images.",
+     version="1.0.0",
+ )
+
+ # Detect device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load model and mapping globally
+ try:
+     model = torch.jit.load("convnext_scripted.pt", map_location=device)
+     model.eval()
+
+     with open("data/label_map.json") as f:
+         label_map = json.load(f)
+     # label_map stores name -> index; invert it to index -> name for inference
+     idx_to_disease = {int(v): k for k, v in label_map.items()}
+ except Exception as e:
+     print(f"Error loading model or labels: {e}")
+     model = None
+
+
+ @app.get("/", include_in_schema=False)
+ async def root():
+     """Redirect users to the Swagger UI automatically."""
+     return RedirectResponse(url="/docs")
+
+
+ @app.post("/predict", tags=["Inference"])
+ async def predict(file: UploadFile = File(...)):
+     """
+     Upload an image of a plant leaf to identify potential diseases.
+     """
+     if model is None:
+         raise HTTPException(status_code=500, detail="Model not loaded on server.")
+
+     if not file.content_type or not file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="File provided is not an image.")
+
+     try:
+         # 1. Read and Preprocess
+         img_bytes = await file.read()
+         image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+
+         # 2. Run Inference
+         disease_name = predict_disease(model, image, idx_to_disease, device=device)
+
+         return {"disease": disease_name}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=7860)
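Note on calling the endpoint: `/predict` takes a multipart upload in a form field named `file`. Below is a minimal standard-library client sketch. The helper names `build_multipart` and `predict_remote` and the default URL `http://localhost:7860/predict` are assumptions for illustration; only the field name, route, and port come from `api/main.py`.

```python
import json
import urllib.request
import uuid
from pathlib import Path


def build_multipart(field_name, filename, content, content_type="image/jpeg"):
    """Encode one file as multipart/form-data; returns (body, Content-Type header)."""
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + content + tail, f"multipart/form-data; boundary={boundary}"


def predict_remote(image_path, url="http://localhost:7860/predict"):
    """POST an image to the API and return the parsed JSON response."""
    path = Path(image_path)
    body, header = build_multipart("file", path.name, path.read_bytes())
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": header}, method="POST"
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)
```

Calling `predict_remote("path/to/leaf.jpg")` against a running server should return a dictionary with a `disease` key, matching the handler's return value.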
configs/CNeXv2t.yaml ADDED
@@ -0,0 +1,60 @@
1
+ experiment_name: "ConvNeXtv2_t"
2
+ seed: 42
3
+
4
+ data:
5
+ train_dir: "data/train"
6
+ val_dir: "data/val"
7
+ image_size: 384
8
+ batch_size: 16
9
+ num_workers: 2
10
+ pin_memory: true
11
+ weighted_sampling: false
12
+ max_weight: 5
13
+
14
+ model:
15
+ backbone: "convnextv2_tiny.fcmae_ft_in22k_in1k_384" # e.g. "tf_efficientnetv2_s" "convnextv2_tiny.fcmae_ft_in22k_in1k_384" eva02_base_patch14_224
16
+ pretrained: true
17
+ freeze_backbone: false
18
+ freeze_bn: true
19
+ num_classes: null # Inferred automatically from dataset
20
+ dropout: 0.2
21
+ drop_path: 0.2
22
+
23
+ loss:
24
+ name: "ce" # "focal" or "ce"
25
+ gamma: 2.0
26
+ alpha: 0.25
27
+ label_smoothing: 0.1
28
+
29
+ optimizer:
30
+ name: "adamw"
31
+ backbone_lr: 3e-5 # ignored when layer-wise decay is active (layer_decay < 1.0)
32
+ head_lr: 3e-4
33
+ weight_decay: 0.05
34
+ layer_decay: 0.9
35
+
36
+ scheduler:
37
+ name: "cosine_warmup" # "cosine", "step", "plateau"
38
+ warmup_epochs: 2
39
+ min_lr: 1e-6
40
+
41
+ training:
42
+ epochs: 10
43
+ gradient_accumulation_steps: 4
44
+ mixed_precision: true
45
+ clip_grad_norm: 1.0
46
+ early_stopping_patience: 5
47
+ ema:
48
+ enabled: true
49
+ decay: 0.999
50
+ eval_mode: "current" # "current" or "ema"
51
+
52
+ augmentation:
53
+ mixup_alpha: 0.8
54
+ cutmix_alpha: 1.0
55
+ prob: 0.5 # Probability applied per batch
56
+
57
+ logging:
58
+ use_wandb: true
59
+ project_name: "plant-disease-classification"
60
+ checkpoint_dir: "./checkpoints"
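Note on the `cosine_warmup` scheduler: `src/utils.py` is not shown here, so the exact schedule is unknown. A common reading of `warmup_epochs` and `min_lr` is a linear ramp followed by cosine annealing, sketched below per epoch. The function `cosine_warmup_lr` is hypothetical; the numbers come from this config.

```python
import math


def cosine_warmup_lr(epoch, base_lr, total_epochs, warmup_epochs, min_lr):
    """Learning rate for a zero-based epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp that reaches base_lr on the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# head_lr=3e-4, epochs=10, warmup_epochs=2, min_lr=1e-6 as in this config.
schedule = [cosine_warmup_lr(e, 3e-4, 10, 2, 1e-6) for e in range(10)]
```

Separately, `batch_size: 16` with `gradient_accumulation_steps: 4` means each optimizer step sees 16 × 4 = 64 samples.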
configs/EffNv2S_aug.yaml ADDED
@@ -0,0 +1,60 @@
1
+ experiment_name: "EffNv2S_aug"
2
+ seed: 42
3
+
4
+ data:
5
+ train_dir: "data/train"
6
+ val_dir: "data/val"
7
+ image_size: 384
8
+ batch_size: 32
9
+ num_workers: 2
10
+ pin_memory: true
11
+ weighted_sampling: false
12
+ max_weight: 5
13
+
14
+ model:
15
+ backbone: "tf_efficientnetv2_s" # e.g. "tf_efficientnetv2_s" "convnextv2_tiny.fcmae_ft_in22k_in1k_384" eva02_base_patch14_224
16
+ pretrained: true
17
+ freeze_backbone: false
18
+ freeze_bn: true
19
+ num_classes: null # Inferred automatically from dataset
20
+ dropout: 0.2
21
+ drop_path: 0.1
22
+
23
+ loss:
24
+ name: "ce" # "focal" or "ce"
25
+ gamma: 2.0
26
+ alpha: 0.25
27
+ label_smoothing: 0.1
28
+
29
+ optimizer:
30
+ name: "adamw"
31
+ backbone_lr: 3e-5 # ignored when layer-wise decay is active (layer_decay < 1.0)
32
+ head_lr: 3e-4
33
+ weight_decay: 1e-2
34
+ layer_decay: 1.0
35
+
36
+ scheduler:
37
+ name: "cosine_warmup" # "cosine", "step", "plateau"
38
+ warmup_epochs: 3
39
+ min_lr: 1e-6
40
+
41
+ training:
42
+ epochs: 15
43
+ gradient_accumulation_steps: 4
44
+ mixed_precision: true
45
+ clip_grad_norm: 1.0
46
+ early_stopping_patience: 5
47
+ ema:
48
+ enabled: true
49
+ decay: 0.999
50
+ eval_mode: "current" # "current" or "ema"
51
+
52
+ augmentation:
53
+ mixup_alpha: 0.8
54
+ cutmix_alpha: 1.0
55
+ prob: 0.5 # Probability applied per batch
56
+
57
+ logging:
58
+ use_wandb: true
59
+ project_name: "plant-disease-classification"
60
+ checkpoint_dir: "./checkpoints"
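Note on `augmentation.mixup_alpha`: MixUp draws a blend weight from a Beta(alpha, alpha) distribution and mixes both inputs and labels with it. The sketch below works on plain lists so it runs without PyTorch. `mixup_pair` is a hypothetical helper, and the repository's own implementation (not shown in this diff) may differ, for example by mixing within a batch.

```python
import random


def mixup_pair(x_a, x_b, y_a, y_b, alpha, rng):
    """Blend two samples and their one-hot labels with a Beta(alpha, alpha) weight."""
    lam = rng.betavariate(alpha, alpha)
    mixed_x = [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]
    mixed_y = [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]
    return mixed_x, mixed_y, lam


rng = random.Random(42)
mixed_x, mixed_y, lam = mixup_pair(
    [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], alpha=0.8, rng=rng
)
```

The mixed label stays a valid probability vector because the two weights sum to one.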
configs/EffNv2S_baseline.yaml ADDED
@@ -0,0 +1,60 @@
1
+ experiment_name: "EffNv2S_baseline"
2
+ seed: 42
3
+
4
+ data:
5
+ train_dir: "data/train"
6
+ val_dir: "data/val"
7
+ image_size: 384
8
+ batch_size: 32
9
+ num_workers: 2
10
+ pin_memory: true
11
+ weighted_sampling: false
12
+ max_weight: 5
13
+
14
+ model:
15
+ backbone: "tf_efficientnetv2_s" # e.g. "tf_efficientnetv2_s" "convnextv2_tiny.fcmae_ft_in22k_in1k_384" eva02_base_patch14_224
16
+ pretrained: true
17
+ freeze_backbone: false
18
+ freeze_bn: true
19
+ num_classes: null # Inferred automatically from dataset
20
+ dropout: 0.2
21
+ drop_path: 0.1
22
+
23
+ loss:
24
+ name: "ce" # "focal" or "ce"
25
+ gamma: 2.0
26
+ alpha: 0.25
27
+ label_smoothing: 0.1
28
+
29
+ optimizer:
30
+ name: "adamw"
31
+ backbone_lr: 3e-5 # ignored when layer-wise decay is active (layer_decay < 1.0)
32
+ head_lr: 3e-4
33
+ weight_decay: 1e-2
34
+ layer_decay: 1.0
35
+
36
+ scheduler:
37
+ name: "cosine_warmup" # "cosine", "step", "plateau"
38
+ warmup_epochs: 3
39
+ min_lr: 1e-6
40
+
41
+ training:
42
+ epochs: 15
43
+ gradient_accumulation_steps: 4
44
+ mixed_precision: true
45
+ clip_grad_norm: 1.0
46
+ early_stopping_patience: 5
47
+ ema:
48
+ enabled: true
49
+ decay: 0.9995
50
+ eval_mode: "current" # "current" or "ema"
51
+
52
+ augmentation:
53
+ mixup_alpha: 0.2
54
+ cutmix_alpha: 0.5
55
+ prob: 0 # Probability applied per batch
56
+
57
+ logging:
58
+ use_wandb: true
59
+ project_name: "plant-disease-classification"
60
+ checkpoint_dir: "./checkpoints"
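Note on `training.ema.decay`: an exponential moving average keeps a shadow copy of the weights and moves it a small step toward the live weights after each update. This is a minimal sketch on plain lists. `ema_update` is a hypothetical name, and the trainer in this repository (not shown here) may apply the update at a different granularity.

```python
def ema_update(ema_weights, weights, decay):
    """Move each shadow weight toward the live weight by a factor of (1 - decay)."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# A deliberately small decay so the effect is visible in three steps.
ema = [0.0, 0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0, 2.0], decay=0.5)
```

As a rule of thumb, a decay of 0.9995 averages over roughly 1 / (1 - 0.9995) = 2000 recent updates.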
configs/EffNv2S_head.yaml ADDED
@@ -0,0 +1,60 @@
1
+ experiment_name: "EffNv2S_head"
2
+ seed: 42
3
+
4
+ data:
5
+ train_dir: "data/train"
6
+ val_dir: "data/val"
7
+ image_size: 384
8
+ batch_size: 64
9
+ num_workers: 2
10
+ pin_memory: true
11
+ weighted_sampling: false
12
+ max_weight: 5
13
+
14
+ model:
15
+ backbone: "tf_efficientnetv2_s" # e.g. "tf_efficientnetv2_s" "convnextv2_tiny.fcmae_ft_in22k_in1k_384" eva02_base_patch14_224
16
+ pretrained: true
17
+ freeze_backbone: true
18
+ freeze_bn: true
19
+ num_classes: null # Inferred automatically from dataset
20
+ dropout: 0.2
21
+ drop_path: 0.1
22
+
23
+ loss:
24
+ name: "ce" # "focal" or "ce"
25
+ gamma: 2.0
26
+ alpha: 0.25
27
+ label_smoothing: 0.1
28
+
29
+ optimizer:
30
+ name: "adamw"
31
+ backbone_lr: 0 # ignored when layer-wise decay is active (layer_decay < 1.0)
32
+ head_lr: 5e-4
33
+ weight_decay: 1e-2
34
+ layer_decay: 1.0
35
+
36
+ scheduler:
37
+ name: "step" # "cosine", "step", "plateau"
38
+ warmup_epochs: 0
39
+ min_lr: 0.0
40
+
41
+ training:
42
+ epochs: 10
43
+ gradient_accumulation_steps: 1
44
+ mixed_precision: true
45
+ clip_grad_norm: 1.0
46
+ early_stopping_patience: 5
47
+ ema:
48
+ enabled: true
49
+ decay: 0.9995
50
+ eval_mode: "current" # "current" or "ema"
51
+
52
+ augmentation:
53
+ mixup_alpha: 0.2
54
+ cutmix_alpha: 0.5
55
+ prob: 0 # Probability applied per batch
56
+
57
+ logging:
58
+ use_wandb: true
59
+ project_name: "plant-disease-classification"
60
+ checkpoint_dir: "./checkpoints"
configs/config.yaml ADDED
@@ -0,0 +1,52 @@
1
+ experiment_name: dinov3_vit_small_plus_baseline
2
+ seed: 42
3
+ data:
4
+ train_dir: data/train
5
+ val_dir: data/val
6
+ image_size: 384
7
+ batch_size: 64
8
+ num_workers: 2
9
+ pin_memory: true
10
+ weighted_sampling: false
11
+ max_weight: 5
12
+ model:
13
+ backbone: vit_small_plus_patch16_dinov3.lvd1689m
14
+ pretrained: true
15
+ freeze_backbone: false
16
+ freeze_bn: false
17
+ num_classes: null
18
+ dropout: 0.1
19
+ drop_path: 0.1
20
+ loss:
21
+ name: ce
22
+ gamma: 2.0
23
+ alpha: 0.25
24
+ label_smoothing: 0.1
25
+ optimizer:
26
+ name: adamw
27
+ backbone_lr: 2.0e-05
28
+ head_lr: 5e-4
29
+ weight_decay: 1e-4
30
+ layer_decay: 1.0
31
+ scheduler:
32
+ name: cosine_warmup
33
+ warmup_epochs: 3
34
+ min_lr: 1e-6
35
+ training:
36
+ epochs: 20
37
+ gradient_accumulation_steps: 2
38
+ mixed_precision: true
39
+ clip_grad_norm: 1.0
40
+ early_stopping_patience: 5
41
+ ema:
42
+ enabled: true
43
+ decay: 0.9994
44
+ eval_mode: current
45
+ augmentation:
46
+ mixup_alpha: 0.8
47
+ cutmix_alpha: 1.0
48
+ prob: 0.5
49
+ logging:
50
+ use_wandb: true
51
+ project_name: plant-disease-classification
52
+ checkpoint_dir: ./checkpoints
configs/dinov3vitS+.yaml ADDED
@@ -0,0 +1,60 @@
1
+ experiment_name: "dinov3_vit_small_plus_baseline"
2
+ seed: 42
3
+
4
+ data:
5
+ train_dir: "data/train"
6
+ val_dir: "data/val"
7
+ image_size: 384
8
+ batch_size: 64
9
+ num_workers: 2
10
+ pin_memory: true
11
+ weighted_sampling: false
12
+ max_weight: 5
13
+
14
+ model:
15
+ backbone: "vit_small_plus_patch16_dinov3.lvd1689m" # e.g. "tf_efficientnetv2_s" "convnextv2_tiny.fcmae_ft_in22k_in1k_384" eva02_base_patch14_224
16
+ pretrained: true
17
+ freeze_backbone: true
18
+ freeze_bn: false
19
+ num_classes: null # Inferred automatically from dataset
20
+ dropout: 0.1
21
+ drop_path: 0.1
22
+
23
+ loss:
24
+ name: "ce" # "focal" or "ce"
25
+ gamma: 2.0
26
+ alpha: 0.25
27
+ label_smoothing: 0.1
28
+
29
+ optimizer:
30
+ name: "adamw"
31
+ backbone_lr: 0 # ignored when layer-wise decay is active (layer_decay < 1.0)
32
+ head_lr: 5e-4
33
+ weight_decay: 1e-4
34
+ layer_decay: 1.0
35
+
36
+ scheduler:
37
+ name: "cosine_warmup" # "cosine", "step", "plateau"
38
+ warmup_epochs: 3
39
+ min_lr: 1e-6
40
+
41
+ training:
42
+ epochs: 15
43
+ gradient_accumulation_steps: 2
44
+ mixed_precision: true
45
+ clip_grad_norm: 1.0
46
+ early_stopping_patience: 5
47
+ ema:
48
+ enabled: true
49
+ decay: 0.9994
50
+ eval_mode: "current" # "current" or "ema"
51
+
52
+ augmentation:
53
+ mixup_alpha: 0.2
54
+ cutmix_alpha: 0.5
55
+ prob: 0 # Probability applied per batch
56
+
57
+ logging:
58
+ use_wandb: true
59
+ project_name: "plant-disease-classification"
60
+ checkpoint_dir: "./checkpoints"
data/label_map.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "alternaria leaf spot": 0,
3
+ "angular leaf spot": 1,
4
+ "anthracnose": 2,
5
+ "bacterial leaf spot": 3,
6
+ "bacterial leaf streak (black chaff)": 4,
7
+ "bacterial wilt": 5,
8
+ "berry blotch": 6,
9
+ "black leaf streak": 7,
10
+ "black rot": 8,
11
+ "blossom end rot": 9,
12
+ "brown rot": 10,
13
+ "brown spot": 11,
14
+ "bunchy top": 12,
15
+ "canker": 13,
16
+ "downy mildew": 14,
17
+ "early blight": 15,
18
+ "frog eye leaf spot": 16,
19
+ "gray leaf spot": 17,
20
+ "greening disease": 18,
21
+ "head scab": 19,
22
+ "late blight": 20,
23
+ "leaf curl": 21,
24
+ "leaf mold": 22,
25
+ "leaf rust": 23,
26
+ "leaf spot": 24,
27
+ "loose smut": 25,
28
+ "mosaic": 26,
29
+ "mosaic virus": 27,
30
+ "northern leaf blight": 28,
31
+ "powdery mildew": 29,
32
+ "rust": 30,
33
+ "scab": 31,
34
+ "septoria blotch": 32,
35
+ "septoria leaf spot": 33,
36
+ "sheath blight": 34,
37
+ "smut": 35,
38
+ "stem rust": 36,
39
+ "stripe rust": 37,
40
+ "tar spot": 38
41
+ }
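Note on using this file: the map stores disease name to class index, while inference needs the reverse lookup, which `api/main.py` builds with a dict comprehension. The sketch below does the same inversion and adds a duplicate-index check. The helper `invert_label_map` and the check are assumptions added for illustration.

```python
def invert_label_map(label_map):
    """Turn {name: index} into {index: name}, rejecting duplicate indices."""
    idx_to_name = {int(idx): name for name, idx in label_map.items()}
    if len(idx_to_name) != len(label_map):
        raise ValueError("label map contains duplicate class indices")
    return idx_to_name


# Two entries copied from data/label_map.json.
inverted = invert_label_map({"rust": 30, "scab": 31})
```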
notebooks/data_analysis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/evaluate.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ timm>=0.9.0
4
+ omegaconf>=2.3.0
5
+ wandb>=0.15.0
6
+ tqdm>=4.65.0
7
+ torchmetrics>=1.0.0
8
+ pillow>=9.0.0
9
+ numpy
src/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # src module init
src/dataset.py ADDED
@@ -0,0 +1,251 @@
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torchvision.transforms.v2 as T
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
+
+
+ class PlantDiseaseDataset(Dataset):
+     def __init__(self, root_dir, label_map=None, transform=None):
+         """
+         Args:
+             root_dir (str): Path to the root directory of the split.
+             label_map (dict, optional): A dictionary mapping disease names to integers.
+                 Crucial for consistency across splits.
+             transform (callable, optional): PyTorch transforms.
+         """
+         self.root_dir = Path(root_dir)
+         self.transform = transform
+
+         self.image_paths = []
+         self.labels = []
+         self.plant_labels = []
+
+         self.plants = [
+             "apple",
+             "banana",
+             "bean",
+             "bell pepper",
+             "blueberry",
+             "basil",
+             "broccoli",
+             "cabbage",
+             "cauliflower",
+             "celery",
+             "cherry",
+             "citrus",
+             "coffee",
+             "corn",
+             "cucumber",
+             "garlic",
+             "ginger",
+             "grape",
+             "lettuce",
+             "maple",
+             "peach",
+             "plum",
+             "potato",
+             "raspberry",
+             "rice",
+             "soybean",
+             "squash",
+             "strawberry",
+             "tobacco",
+             "tomato",
+             "wheat",
+             "zucchini",
+         ]
+         self.plants.sort(key=len, reverse=True)
+
+         if not self.root_dir.exists():
+             return
+
+         if label_map is None:
+             self.disease_to_idx = self._build_label_map()
+         else:
+             self.disease_to_idx = label_map
+
+         for folder_name in sorted([d for d in self.root_dir.iterdir() if d.is_dir()]):
+             disease, plant = self._split_plant_disease(folder_name)
+
+             if disease not in self.disease_to_idx:
+                 print(
+                     f"WARNING: Skipping '{folder_name.name}': Disease '{disease}' not found in label_map"
+                 )
+                 continue
+
+             disease_idx = self.disease_to_idx[disease]
+
+             for img_path in folder_name.glob("**/*"):
+                 if img_path.is_file() and img_path.suffix.lower() in [
+                     ".jpg",
+                     ".jpeg",
+                     ".png",
+                     ".webp",
+                 ]:
+                     self.image_paths.append(str(img_path))
+                     self.labels.append(disease_idx)
+                     self.plant_labels.append(plant)
+
+         self.classes = list(self.disease_to_idx.keys())
+
+     def _build_label_map(self):
+         all_diseases = set()
+
+         for folder in sorted([d for d in self.root_dir.iterdir() if d.is_dir()]):
+             folder_name = folder.name.lower()
+             for plant in self.plants:
+                 if folder_name.startswith(plant):
+                     disease_name = folder_name[len(plant) :].strip()
+                     all_diseases.add(disease_name)
+                     break
+
+         return {disease: i for i, disease in enumerate(sorted(all_diseases))}
+
+     def _split_plant_disease(self, folder):
+         folder_name = folder.name.lower()
+         for plant in self.plants:
+             if folder_name.startswith(plant):
+                 disease = folder_name[len(plant) :].strip()
+                 return disease, plant
+
+         return None, None
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         img_path = self.image_paths[idx]
+         label = self.labels[idx]
+
+         try:
+             image = Image.open(img_path).convert("RGB")
+         except Exception as e:
+             print(f"Error loading {img_path}: {e}")
+             return None, None
+
+         if self.transform:
+             image = self.transform(image)
+
+         return image, label
+
+
+ def get_transforms(image_size=384, is_train=True):
+     if is_train:
+         return T.Compose(
+             [
+                 T.RandomResizedCrop(image_size, scale=(0.7, 1.0), antialias=True),
+                 T.RandomHorizontalFlip(),
+                 T.TrivialAugmentWide(),
+                 T.ToImage(),
+                 T.ToDtype(torch.float32, scale=True),
+                 T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+             ]
+         )
+     else:
+         return T.Compose(
+             [
+                 T.Resize((image_size, image_size), antialias=True),
+                 T.ToImage(),
+                 T.ToDtype(torch.float32, scale=True),
+                 T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+             ]
+         )
+
+
+ def build_weighted_sampler(dataset, max_weight=10.0):
+     if not hasattr(dataset, "labels") or not hasattr(dataset, "plant_labels"):
+         raise ValueError("Dataset must have 'labels' and 'plant_labels'")
+
+     if len(dataset) == 0:
+         return None
+
+     disease = torch.tensor(dataset.labels, dtype=torch.long)
+     _, plant_indices = np.unique(dataset.plant_labels, return_inverse=True)
+     plant = torch.tensor(plant_indices, dtype=torch.long)
+
+     disease_counts = torch.bincount(disease)
+
+     pairs = torch.stack([disease, plant], dim=1)
+     _, group_id = torch.unique(pairs, return_inverse=True, dim=0)
+     group_counts = torch.bincount(group_id)
+
+     d_count = disease_counts[disease]
+     g_count = group_counts[group_id]
+
+     weights = 1.0 / torch.sqrt(d_count.float() * g_count.float())
+
+     if max_weight:
+         weights = torch.clamp(weights, max=max_weight)
+
+     return WeightedRandomSampler(
+         weights=weights,
+         num_samples=len(dataset),
+         replacement=True,
+     )
+
+
+ def get_dataloaders(config):
+     train_dir = Path(config.data.train_dir)
+     val_dir = Path(config.data.val_dir)
+
+     train_dir.mkdir(parents=True, exist_ok=True)
+     val_dir.mkdir(parents=True, exist_ok=True)
+
+     # Load the existing label map; if none exists, the train dataset builds one
+     label_map_path = train_dir.parent / "label_map.json"
+     if label_map_path.exists():
+         with open(label_map_path) as f:
+             label_map = json.load(f)
+     else:
+         label_map = None
+
+     # create datasets with consistent label_map
+     train_dataset = PlantDiseaseDataset(
+         config.data.train_dir,
+         label_map=label_map,
+         transform=get_transforms(config.data.image_size, is_train=True),
+     )
+     val_dataset = PlantDiseaseDataset(
+         config.data.val_dir,
+         label_map=train_dataset.disease_to_idx,
+         transform=get_transforms(config.data.image_size, is_train=False),
+     )
+
+     # save label_map for future
+     with open(label_map_path, "w") as f:
+         json.dump(train_dataset.disease_to_idx, f, indent=2)
+
+     if len(train_dataset) == 0:
+         print("Warning: No train data found. Dataloader might fail.")
+
+     num_classes = len(train_dataset.classes) if len(train_dataset) > 0 else 0
+
+     train_sampler = None
+     if config.data.weighted_sampling:
+         train_sampler = build_weighted_sampler(
+             train_dataset, max_weight=config.data.max_weight
+         )
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=config.data.batch_size,
+         sampler=train_sampler,
+         shuffle=(train_sampler is None),
+         num_workers=config.data.num_workers,
+         pin_memory=config.data.pin_memory,
+         drop_last=len(train_dataset) > config.data.batch_size,
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=config.data.batch_size,
+         shuffle=False,
+         num_workers=config.data.num_workers,
+         pin_memory=config.data.pin_memory,
+     )
+
+     return train_loader, val_loader, num_classes
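Note on `build_weighted_sampler`: each sample's weight is the inverse square root of its disease count times its (disease, plant) group count. The sketch below reproduces that arithmetic with the standard library so it can be checked by hand. `sample_weights` is a hypothetical name, and the toy labels are made up for illustration.

```python
import math
from collections import Counter


def sample_weights(disease_labels, plant_labels, max_weight=None):
    """Per-sample weight 1 / sqrt(disease_count * group_count), optionally clamped."""
    disease_counts = Counter(disease_labels)
    group_counts = Counter(zip(disease_labels, plant_labels))
    weights = []
    for disease, plant in zip(disease_labels, plant_labels):
        weight = 1.0 / math.sqrt(disease_counts[disease] * group_counts[(disease, plant)])
        if max_weight:
            weight = min(weight, max_weight)
        weights.append(weight)
    return weights


# Toy data: three samples of disease 0 (two apple, one corn) and one of disease 1.
weights = sample_weights([0, 0, 0, 1], ["apple", "apple", "corn", "corn"], max_weight=5)
```

Because both counts are at least 1, these weights never exceed 1.0, so a clamp at `max_weight: 5` leaves them unchanged as the formula is written.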
src/infer.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision.transforms.v2 as T
9
+ from PIL import Image
10
+ from sklearn.metrics import accuracy_score, average_precision_score
11
+
12
+ from dataset import get_transforms
13
+
14
+
15
+ def get_tta_transforms(image_size):
16
+ return [
17
+ T.Compose(
18
+ [
19
+ T.Resize((image_size, image_size), antialias=True),
20
+ T.ToImage(),
21
+ T.ToDtype(torch.float32, scale=True),
22
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
23
+ ]
24
+ ),
25
+ T.Compose(
26
+ [
27
+ T.Resize((image_size, image_size), antialias=True),
28
+ T.RandomHorizontalFlip(p=1.0),
29
+ T.ToImage(),
30
+ T.ToDtype(torch.float32, scale=True),
31
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
+ ]
33
+ ),
34
+ T.Compose(
35
+ [
36
+ T.Resize(int(image_size * 1.1), antialias=True),
37
+ T.CenterCrop(image_size),
38
+ T.ToImage(),
39
+ T.ToDtype(torch.float32, scale=True),
40
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
41
+ ]
42
+ ),
43
+ ]
44
+
45
+
46
+ def evaluate(model, val_loader, device=None, use_tta=False, image_size=384):
47
+ if device is None:
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+
50
+ model = model.to(device)
51
+ model.eval()
52
+
53
+ if use_tta:
54
+ tta_transforms = get_tta_transforms(image_size)
55
+
56
+ all_probs = []
57
+ all_labels = []
58
+
59
+ with torch.inference_mode():
60
+ for images, labels in val_loader:
61
+ images = images.to(device)
62
+ labels = labels.to(device)
63
+
64
+ if use_tta:
65
+ tta_batches = []
66
+
67
+ for transform in tta_transforms:
68
+ augmented = torch.stack([transform(img.cpu()) for img in images])
69
+ tta_batches.append(augmented)
70
+
71
+ tta_batches = torch.stack(tta_batches).to(device)
72
+
73
+ outputs = []
74
+ for tta_batch in tta_batches:
75
+ out = model(tta_batch) # [batch, num_classes]
76
+ outputs.append(out)
77
+
78
+ outputs = torch.stack(outputs).mean(dim=0)
79
+
80
+ else:
81
+ outputs = model(images)
82
+
83
+ probs = torch.softmax(outputs, dim=1)
84
+
85
+ all_probs.append(probs.cpu())
86
+ all_labels.append(labels.cpu())
87
+
88
+ all_probs = torch.cat(all_probs).numpy()
89
+ all_labels = torch.cat(all_labels).numpy()
90
+
91
+ preds = np.argmax(all_probs, axis=1)
92
+ acc = accuracy_score(all_labels, preds)
93
+
94
+ num_classes = all_probs.shape[1]
95
+ y_true_bin = np.zeros((len(all_labels), num_classes))
96
+ y_true_bin[np.arange(len(all_labels)), all_labels] = 1
97
+
98
+ per_class_ap = []
99
+ for i in range(num_classes):
100
+ if y_true_bin[:, i].sum() > 0:
101
+ ap = average_precision_score(y_true_bin[:, i], all_probs[:, i])
102
+ per_class_ap.append(ap)
103
+
104
+ mAP = np.mean(per_class_ap)
105
+
106
+ return acc, mAP, all_probs, all_labels
107
+
108
+
109
+ def predict_disease(
110
+ model, image, idx_to_disease, image_size=384, use_tta=False, device=None
111
+ ):
112
+ if device is None:
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+
115
+ model = model.to(device)
116
+ model.eval()
117
+
118
+ if use_tta:
119
+ transforms = get_tta_transforms(image_size)
120
+ tensors = [transform(image).unsqueeze(0) for transform in transforms]
121
+ batch = torch.cat(tensors, dim=0).to(device)
122
+
123
+ with torch.inference_mode():
124
+ outputs = model(batch)
125
+ output = outputs.mean(dim=0, keepdim=True)
126
+
127
+ else:
128
+ transform = get_transforms(image_size, is_train=False)
129
+ tensor = transform(image).unsqueeze(0).to(device)
130
+
131
+ with torch.inference_mode():
132
+ output = model(tensor)
133
+
134
+ probs = output.softmax(dim=1)
135
+ disease_name = idx_to_disease[probs.argmax(dim=1).item()]
136
+
137
+ return disease_name
138
+
139
+
140
+ if __name__ == "__main__":
141
+ parser = argparse.ArgumentParser(
142
+ description="Run inference on a plant disease image"
143
+ )
144
+ parser.add_argument("--image_path", type=str, required=True, help="Path to input image")
145
+ parser.add_argument(
146
+ "--image_size", type=int, default=384, help="Size of input image"
147
+ )
148
+ parser.add_argument(
149
+ "--checkpoint", type=str, required=True, help="Path to TorchScript checkpoint"
150
+ )
151
+ parser.add_argument("--tta", action="store_true", help="Use test time augmentation")
152
+ args = parser.parse_args()
153
+
154
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155
+
156
+ model = torch.jit.load(args.checkpoint).to(device)
157
+ model.eval()
158
+ print(args.tta)
159
+ # load label map
160
+ data_dir = Path("data")
161
+ label_map_path = data_dir / "label_map.json"
162
+ with open(label_map_path) as f:
163
+ label_map = json.load(f)
164
+ idx_to_disease = {int(v): k for k, v in label_map.items()}
165
+
166
+ image = Image.open(args.image_path).convert("RGB")
167
+
168
+ result = predict_disease(
169
+ model,
170
+ image,
171
+ image_size=args.image_size,
172
+ idx_to_disease=idx_to_disease,
173
+ use_tta=args.tta,
174
+ )
175
+ print(f"Disease: {result}")
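The `evaluate` function above reports mAP as the mean of per-class average precision, skipping classes that have no positive samples. As a rough illustration of what each per-class term measures, here is a pure-Python sketch. `average_precision` and `mean_average_precision` are hypothetical helpers written for this note, not the scikit-learn implementation, and the sketch assumes there are no tied scores.

```python
def average_precision(y_true, scores):
    """Mean of the precision values taken at each positive, ranked by descending score."""
    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    if not precisions:
        raise ValueError("class has no positive samples")
    return sum(precisions) / len(precisions)


def mean_average_precision(y_true_bin, probs):
    """Average AP over classes with at least one positive, mirroring evaluate()."""
    num_classes = len(probs[0])
    aps = []
    for c in range(num_classes):
        col_true = [row[c] for row in y_true_bin]
        if sum(col_true) > 0:
            col_scores = [row[c] for row in probs]
            aps.append(average_precision(col_true, col_scores))
    return sum(aps) / len(aps)
```

A class whose positives are all ranked first gets an AP of 1.0; every negative ranked above a positive pulls the value down.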
src/loss.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.loss import SoftTargetCrossEntropy
5
+
6
+
7
+ class FocalLoss(nn.Module):
8
+ def __init__(self, alpha=0.25, gamma=2.0, reduction="mean", label_smoothing=0.0):
9
+ super().__init__()
10
+ self.alpha = alpha
11
+ self.gamma = gamma
12
+ self.reduction = reduction
13
+ self.label_smoothing = label_smoothing
14
+
15
+ def forward(self, inputs, targets):
16
+ """
17
+ inputs: logits [B, C]
18
+ targets: labels [B] or soft mixup labels [B, C]
19
+ """
20
+ if targets.ndim == inputs.ndim:
21
+ # targets are soft labels from MixUp/CutMix
22
+ ce_loss = F.cross_entropy(
23
+ inputs, targets, reduction="none", label_smoothing=self.label_smoothing
24
+ )
25
+ # for focal weighting when using mixup, pt is e^(-ce_loss)
26
+ pt = torch.exp(-ce_loss)
27
+ else:
28
+ ce_loss = F.cross_entropy(
29
+ inputs, targets, reduction="none", label_smoothing=self.label_smoothing
30
+ )
31
+ pt = torch.exp(-ce_loss)
32
+
33
+ focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
34
+
35
+ if self.reduction == "mean":
36
+ return focal_loss.mean()
37
+ elif self.reduction == "sum":
38
+ return focal_loss.sum()
39
+ return focal_loss
40
+
41
+
42
+ def get_criterion(config):
43
+ if config.loss.name == "focal":
44
+ return FocalLoss(
45
+ gamma=config.loss.gamma,
46
+ alpha=config.loss.alpha,
47
+ label_smoothing=config.loss.label_smoothing,
48
+ )
49
+ else:
50
+ if config.augmentation.prob > 0:
51
+ return SoftTargetCrossEntropy()
52
+ return nn.CrossEntropyLoss(label_smoothing=config.loss.label_smoothing)
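The `FocalLoss` module above scales the per-sample cross-entropy by `alpha * (1 - pt) ** gamma`, where `pt = exp(-ce_loss)`. The following is a minimal single-sample sketch of that weighting in pure Python, assuming a hard label and no label smoothing. The `softmax` and `focal_loss` helpers are hypothetical names used only for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Focal loss for one sample with a hard integer label."""
    ce = -math.log(softmax(logits)[target])  # plain cross-entropy
    pt = math.exp(-ce)  # probability assigned to the true class
    return alpha * (1.0 - pt) ** gamma * ce
```

With `gamma=0` and `alpha=1` the expression reduces to ordinary cross-entropy. Larger `gamma` shrinks the contribution of samples the model already classifies confidently.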
src/metrics.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision
3
+
4
+
5
+ class MetricTracker:
6
+ def __init__(self, num_classes, device):
7
+ self.num_classes = num_classes
8
+ self.device = device
9
+
10
+ self.map_metric = MulticlassAveragePrecision(num_classes=num_classes).to(device)
11
+ self.acc_metric = MulticlassAccuracy(num_classes=num_classes).to(device)
12
+
13
+ self.reset()
14
+
15
+ def reset(self):
16
+ self.map_metric.reset()
17
+ self.acc_metric.reset()
18
+ self.loss_sum = 0
19
+ self.count = 0
20
+
21
+ def update(self, preds, targets, loss=None, skip_metrics=False):
22
+ """
23
+ preds: logits [B, C]
24
+ targets: [B] or soft labels [B, C]
25
+ skip_metrics: If True, only loss is tracked. Use for MixUp/CutMix batches.
26
+ """
27
+ if targets.ndim > 1:
28
+ hard_targets = targets.argmax(dim=1)
29
+ else:
30
+ hard_targets = targets
31
+ if not skip_metrics:
32
+ self.map_metric.update(preds, hard_targets)
33
+ self.acc_metric.update(preds, hard_targets)
34
+
35
+ if loss is not None:
36
+ self.loss_sum += float(loss) * preds.size(0)
37
+ self.count += preds.size(0)
38
+
39
+ def compute(self):
40
+ mAP = self.map_metric.compute().item()
41
+ acc = self.acc_metric.compute().item()
42
+ avg_loss = self.loss_sum / max(self.count, 1)
43
+ return {"mAP": mAP, "accuracy": acc, "loss": avg_loss}
src/models.py ADDED
@@ -0,0 +1,72 @@
1
+ import timm
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class PlantDiseaseModel(nn.Module):
7
+ def __init__(self, config, num_classes):
8
+ super().__init__()
9
+ self.backbone_name = config.model.backbone
10
+
11
+ self.model = timm.create_model(
12
+ self.backbone_name,
13
+ pretrained=config.model.pretrained,
14
+ num_classes=num_classes,
15
+ drop_rate=config.model.dropout,
16
+ drop_path_rate=config.model.drop_path,
17
+ )
18
+
19
+ if config.model.freeze_backbone:
20
+ self._freeze_backbone()
21
+ if config.model.freeze_bn:
22
+ self.freeze_bn()
23
+
24
+ def _freeze_backbone(self):
25
+ for param in self.model.parameters():
26
+ param.requires_grad = False
27
+
28
+ if hasattr(self.model, "get_classifier"):
29
+ classifier = self.model.get_classifier()
30
+ for param in classifier.parameters():
31
+ param.requires_grad = True
32
+ else:
33
+ for name, param in self.model.named_parameters():
34
+ if "head" in name or "classifier" in name:
35
+ param.requires_grad = True
36
+
37
+ def freeze_bn(self):
38
+ for module in self.model.modules():
39
+ if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
40
+ module.eval()
41
+
42
+ if module.weight is not None:
43
+ module.weight.requires_grad = False
44
+ if module.bias is not None:
45
+ module.bias.requires_grad = False
46
+
47
+ def forward(self, x):
48
+ return self.model(x)
49
+
50
+
51
+ def get_param_groups(model, base_lr, head_lr, weight_decay):
52
+ if hasattr(model.model, "get_classifier"):
53
+ head = model.model.get_classifier()
54
+ head_params = list(head.parameters())
55
+ head_param_ids = set(id(p) for p in head_params)
56
+ else:
57
+ # fallback
58
+ head_params = []
59
+ for name, p in model.named_parameters():
60
+ if any(k in name for k in ["head", "classifier"]):
61
+ head_params.append(p)
62
+ head_param_ids = set(id(p) for p in head_params)
63
+
64
+ head_params = [p for p in head_params if p.requires_grad]
65
+
66
+ backbone_params = [
67
+ p for p in model.parameters() if id(p) not in head_param_ids and p.requires_grad
68
+ ]
69
+ return [
70
+ {"params": backbone_params, "lr": base_lr, "weight_decay": weight_decay},
71
+ {"params": head_params, "lr": head_lr, "weight_decay": weight_decay},
72
+ ]
src/trainer.py ADDED
@@ -0,0 +1,312 @@
1
+ import torch
2
+ import wandb
3
+ from omegaconf import OmegaConf
4
+ from timm.utils import ModelEmaV2
5
+ from torch import nn
6
+ from torch.amp import GradScaler, autocast
7
+ from torchvision.transforms import v2
8
+ from tqdm import tqdm
9
+
10
+ from .metrics import MetricTracker
11
+ from .utils import EarlyStopping, save_checkpoint
12
+
13
+
14
+ class Trainer:
15
+ def __init__(
16
+ self,
17
+ model,
18
+ train_loader,
19
+ val_loader,
20
+ criterion,
21
+ optimizer,
22
+ scheduler,
23
+ config,
24
+ device,
25
+ ):
26
+ self.model = model
27
+ self.train_loader = train_loader
28
+ self.val_loader = val_loader
29
+ self.criterion = criterion
30
+ self.optimizer = optimizer
31
+ self.scheduler = scheduler
32
+ self.config = config
33
+ self.device = device
34
+
35
+ self.early_stopping = EarlyStopping(
36
+ patience=config.training.early_stopping_patience, mode="max"
37
+ )
38
+
39
+ self.scaler = GradScaler(device.type, enabled=config.training.mixed_precision)
40
+
41
+ self.use_ema = (
42
+ getattr(config.training, "ema", None) and config.training.ema.enabled
43
+ )
44
+ if self.use_ema:
45
+ ema_decay = getattr(config.training.ema, "decay", 0.9999)
46
+ self.model_ema = ModelEmaV2(self.model, decay=ema_decay, device=device)
47
+ else:
48
+ self.model_ema = None
49
+
50
+ self.num_classes = config.model.num_classes
51
+
52
+ self.use_mixup = False
53
+ if config.augmentation.prob > 0:
54
+ self.use_mixup = True
55
+ cutmix = v2.CutMix(
56
+ alpha=config.augmentation.cutmix_alpha, num_classes=self.num_classes
57
+ )
58
+ mixup = v2.MixUp(
59
+ alpha=config.augmentation.mixup_alpha, num_classes=self.num_classes
60
+ )
61
+ self.cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
62
+
63
+ self.train_metrics = MetricTracker(num_classes=self.num_classes, device=device)
64
+ self.val_metrics = MetricTracker(num_classes=self.num_classes, device=device)
65
+ if self.use_ema:
66
+ self.val_ema_metrics = MetricTracker(
67
+ num_classes=self.num_classes, device=device
68
+ )
69
+
70
+ def train_one_epoch(self, epoch):
71
+ self.model.train()
72
+ if self.config.model.freeze_bn:
73
+ for module in self.model.modules():
74
+ if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
75
+ module.eval()
76
+
77
+ self.train_metrics.reset()
78
+
79
+ pbar = tqdm(self.train_loader, desc=f"Epoch {epoch} [Train]")
80
+ for batch_idx, (images, targets) in enumerate(pbar):
81
+ images, targets = images.to(self.device), targets.to(self.device)
82
+ is_mixed = False
83
+
84
+ # apply MixUp or CutMix
85
+ if self.use_mixup and torch.rand(1).item() < self.config.augmentation.prob:
86
+ images, targets = self.cutmix_or_mixup(images, targets)
87
+ is_mixed = True
88
+ if targets.ndim == 1:
89
+ targets = torch.nn.functional.one_hot(
90
+ targets, num_classes=self.num_classes
91
+ ).float()
92
+ with autocast(
93
+ device_type=self.device.type,
94
+ enabled=self.config.training.mixed_precision,
95
+ ):
96
+ outputs = self.model(images)
97
+ loss = self.criterion(outputs, targets)
98
+ # gradient accumulation normalizer
99
+ loss = loss / self.config.training.gradient_accumulation_steps
100
+
101
+ self.scaler.scale(loss).backward()
102
+
103
+ if (batch_idx + 1) % self.config.training.gradient_accumulation_steps == 0:
104
+ if self.config.training.clip_grad_norm > 0:
105
+ self.scaler.unscale_(self.optimizer)
106
+ torch.nn.utils.clip_grad_norm_(
107
+ self.model.parameters(), self.config.training.clip_grad_norm
108
+ )
109
+
110
+ self.scaler.step(self.optimizer)
111
+ self.scaler.update()
112
+ self.optimizer.zero_grad()
113
+
114
+ if self.config.scheduler.name == "cosine_warmup":
115
+ self.scheduler.step()
116
+
117
+ if self.use_ema:
118
+ self.model_ema.update(self.model)
119
+
120
+ batch_loss = loss.item() * self.config.training.gradient_accumulation_steps
121
+ self.train_metrics.update(
122
+ outputs.detach(),
123
+ targets.detach(),
124
+ loss=batch_loss,
125
+ skip_metrics=is_mixed,
126
+ )
127
+
128
+ pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
129
+
130
+ if self.config.logging.use_wandb:
131
+ wandb.log({"train/batch_loss": batch_loss})
132
+
133
+ if (batch_idx + 1) % self.config.training.gradient_accumulation_steps != 0:
134
+ if self.config.training.clip_grad_norm > 0:
135
+ self.scaler.unscale_(self.optimizer)
136
+ torch.nn.utils.clip_grad_norm_(
137
+ self.model.parameters(), self.config.training.clip_grad_norm
138
+ )
139
+
140
+ self.scaler.step(self.optimizer)
141
+ self.scaler.update()
142
+ self.optimizer.zero_grad()
143
+
144
+ if self.config.scheduler.name == "cosine_warmup":
145
+ self.scheduler.step()
146
+
147
+ if self.use_ema:
148
+ self.model_ema.update(self.model)
149
+
150
+ metrics = self.train_metrics.compute()
151
+
152
+ # Step schedulers that step per epoch
153
+ if self.config.scheduler.name == "step":
154
+ self.scheduler.step()
155
+ elif self.config.scheduler.name == "cosine":
156
+ self.scheduler.step()
157
+
158
+ return metrics
159
+
160
+ def validate(self, epoch):
161
+ self.model.eval()
162
+ self.val_metrics.reset()
163
+
164
+ if self.use_ema:
165
+ self.model_ema.module.eval()
166
+ self.val_ema_metrics.reset()
167
+
168
+ pbar = tqdm(self.val_loader, desc=f"Epoch {epoch} [Val]")
169
+ with torch.no_grad():
170
+ for images, targets in pbar:
171
+ images, targets = images.to(self.device), targets.to(self.device)
172
+
173
+ if targets.ndim == 1:
174
+ targets = torch.nn.functional.one_hot(
175
+ targets, num_classes=self.num_classes
176
+ ).float()
177
+
178
+ with autocast(
179
+ device_type=self.device.type,
180
+ enabled=self.config.training.mixed_precision,
181
+ ):
182
+ outputs = self.model(images)
183
+ loss = self.criterion(outputs, targets)
184
+
185
+ if self.use_ema:
186
+ ema_outputs = self.model_ema.module(images)
187
+ ema_loss = self.criterion(ema_outputs, targets)
188
+
189
+ self.val_metrics.update(
190
+ outputs.detach(), targets.detach(), loss=loss.detach()
191
+ )
192
+ if self.use_ema:
193
+ self.val_ema_metrics.update(
194
+ ema_outputs.detach(), targets.detach(), loss=ema_loss.detach()
195
+ )
196
+ pbar.set_postfix(
197
+ {
198
+ "loss": f"{loss.item():.4f}",
199
+ "ema_loss": f"{ema_loss.item():.4f}",
200
+ }
201
+ )
202
+ else:
203
+ pbar.set_postfix({"loss": f"{loss.item():.4f}"})
204
+
205
+ metrics = {"current": self.val_metrics.compute()}
206
+ if self.use_ema:
207
+ metrics["ema"] = self.val_ema_metrics.compute()
208
+
209
+ primary_map = metrics[self.config.training.ema.eval_mode]["mAP"]
210
+
211
+ if self.config.scheduler.name == "plateau":
212
+ self.scheduler.step(primary_map)
213
+
214
+ return metrics
215
+
216
+ def fit(self, start_epoch=1):
217
+ best_map = 0.0
218
+
219
+ for epoch in range(start_epoch, self.config.training.epochs + 1):
220
+ train_metrics = self.train_one_epoch(epoch)
221
+ val_metrics = self.validate(epoch)
222
+
223
+ lrs = [pg["lr"] for pg in self.optimizer.param_groups]
224
+
225
+ log_dict = {
226
+ "train/loss": train_metrics["loss"],
227
+ "train/mAP": train_metrics["mAP"],
228
+ "train/accuracy": train_metrics["accuracy"],
229
+ "lr/backbone": lrs[0],
230
+ "lr/head": lrs[1],
231
+ "epoch": epoch,
232
+ }
233
+
234
+ if self.use_ema:
235
+ log_dict.update(
236
+ {
237
+ "val/loss": val_metrics["current"]["loss"],
238
+ "val/mAP": val_metrics["current"]["mAP"],
239
+ "val/accuracy": val_metrics["current"]["accuracy"],
240
+ "val/ema_loss": val_metrics["ema"]["loss"],
241
+ "val/ema_mAP": val_metrics["ema"]["mAP"],
242
+ "val/ema_accuracy": val_metrics["ema"]["accuracy"],
243
+ }
244
+ )
245
+ else:
246
+ log_dict.update(
247
+ {
248
+ "val/loss": val_metrics["current"]["loss"],
249
+ "val/mAP": val_metrics["current"]["mAP"],
250
+ "val/accuracy": val_metrics["current"]["accuracy"],
251
+ }
252
+ )
253
+
254
+ if self.config.logging.use_wandb:
255
+ wandb.log(log_dict)
256
+
257
+ print(f"\nEpoch {epoch} Summary:")
258
+ print(f"LR: Backbone: {lrs[0]:.2e} | Head: {lrs[1]:.2e}")
259
+ print(
260
+ f"Train - Loss: {train_metrics['loss']:.4f}, mAP: {train_metrics['mAP']:.4f}, Acc: {train_metrics['accuracy']:.4f}"
261
+ )
262
+ if self.use_ema:
263
+ print(
264
+ f"Val (Current) - Loss: {val_metrics['current']['loss']:.4f}, mAP: {val_metrics['current']['mAP']:.4f}, Acc: {val_metrics['current']['accuracy']:.4f}"
265
+ )
266
+ print(
267
+ f"Val (EMA) - Loss: {val_metrics['ema']['loss']:.4f}, mAP: {val_metrics['ema']['mAP']:.4f}, Acc: {val_metrics['ema']['accuracy']:.4f}"
268
+ )
269
+ else:
270
+ print(
271
+ f"Val - Loss: {val_metrics['current']['loss']:.4f}, mAP: {val_metrics['current']['mAP']:.4f}, Acc: {val_metrics['current']['accuracy']:.4f}"
272
+ )
273
+
274
+ primary_map = val_metrics[self.config.training.ema.eval_mode]["mAP"]
275
+ is_best = self.early_stopping(primary_map)
276
+
277
+ if is_best:
278
+ best_map = primary_map
279
+ print(f"Epoch {epoch} is the new best model. mAP: {best_map:.4f}")
280
+
281
+ # Checkpointing
282
+ state = {
283
+ "epoch": epoch,
284
+ "state_dict": self.model.state_dict(),
285
+ "state_dict_ema": self.model_ema.module.state_dict()
286
+ if self.use_ema
287
+ else None,
288
+ "optimizer": self.optimizer.state_dict(),
289
+ "scheduler": self.scheduler.state_dict() if self.scheduler else None,
290
+ "scaler": self.scaler.state_dict(),
291
+ "early_stopping": {
292
+ "best_score": self.early_stopping.best_score,
293
+ "counter": self.early_stopping.counter,
294
+ "early_stop": self.early_stopping.early_stop,
295
+ },
296
+ "rng_states": {
297
+ "torch": torch.get_rng_state(),
298
+ "cuda": torch.cuda.get_rng_state_all()
299
+ if torch.cuda.is_available()
300
+ else None,
301
+ },
302
+ "val_mAP": primary_map,
303
+ "config": OmegaConf.to_yaml(self.config),
304
+ "wandb_run_id": wandb.run.id if wandb.run is not None else None,
305
+ }
306
+ save_checkpoint(state, is_best, self.config.logging.checkpoint_dir)
307
+
308
+ if self.early_stopping.early_stop:
309
+ print(f"Early stopping triggered at epoch {epoch}")
310
+ break
311
+
312
+ print("Training complete!")
src/utils.py ADDED
@@ -0,0 +1,120 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from omegaconf import OmegaConf
10
+ from PIL import Image
11
+
12
+
13
+ class EarlyStopping:
14
+ def __init__(self, patience=7, mode="max"):
15
+ self.patience = patience
16
+ self.mode = mode
17
+ self.counter = 0
18
+ self.best_score = None
19
+ self.early_stop = False
20
+
21
+ def __call__(self, metric_value):
22
+ score = -metric_value if self.mode == "min" else metric_value
23
+
24
+ if self.best_score is None:
25
+ self.best_score = score
26
+ return True
27
+ elif score < self.best_score:
28
+ self.counter += 1
29
+ if self.counter >= self.patience:
30
+ self.early_stop = True
31
+ return False
32
+ else:
33
+ self.best_score = score
34
+ self.counter = 0
35
+ return True
36
+
37
+
38
+ class CosineAnnealingWarmupLR(torch.optim.lr_scheduler._LRScheduler):
39
+ def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0, last_epoch=-1):
40
+ self.warmup_steps = warmup_steps
41
+ self.total_steps = total_steps
42
+ self.min_lr = min_lr
43
+
44
+ self.min_lr_ratios = []
45
+ for group in optimizer.param_groups:
46
+ ratio = min_lr / max(group["lr"], 1e-12)
47
+ self.min_lr_ratios.append(ratio)
48
+
49
+ super().__init__(optimizer, last_epoch)
50
+
51
+ def get_lr(self):
52
+ curr_step = self.last_epoch
53
+
54
+ # linear warmup phase
55
+ if curr_step < self.warmup_steps:
56
+ scale = curr_step / max(1, self.warmup_steps)
57
+ return [base_lr * scale for base_lr in self.base_lrs]
58
+
59
+ # cosine annealing phase
60
+ progress = (curr_step - self.warmup_steps) / max(
61
+ 1, self.total_steps - self.warmup_steps
62
+ )
63
+ progress = min(1.0, max(0.0, progress))
64
+ cosine = 0.5 * (1 + math.cos(math.pi * progress))
65
+
66
+ return [
67
+ base_lr * (ratio + (1 - ratio) * cosine)
68
+ for base_lr, ratio in zip(self.base_lrs, self.min_lr_ratios)
69
+ ]
70
+
71
+
72
+ def set_seed(seed=42, deterministic=False):
73
+ random.seed(seed)
74
+ np.random.seed(seed)
75
+ torch.manual_seed(seed)
76
+ torch.cuda.manual_seed_all(seed)
77
+ if deterministic:
78
+ torch.backends.cudnn.deterministic = True
79
+ torch.backends.cudnn.benchmark = False
80
+
81
+
82
+ def load_config(config_path):
83
+ return OmegaConf.load(config_path)
84
+
85
+
86
+ def save_checkpoint(state, is_best, checkpoint_dir, filename="last.pt"):
87
+ os.makedirs(checkpoint_dir, exist_ok=True)
88
+ epoch = state["epoch"]
89
+ filename = f"checkpoint_epoch_{epoch}.pt"
90
+ filepath = os.path.join(checkpoint_dir, filename)
91
+ torch.save(state, filepath)
92
+
93
+ last_path = os.path.join(checkpoint_dir, "last.pt")
94
+ shutil.copyfile(filepath, last_path)
95
+
96
+ if is_best:
97
+ best_path = os.path.join(checkpoint_dir, "best.pt")
98
+ shutil.copyfile(filepath, best_path)
99
+
100
+
101
+ def check_dataset(data_dir):
102
+ data_path = Path(data_dir)
103
+ corrupt_files = []
104
+
105
+ print(f"Checking images in {data_dir}...")
106
+
107
+ for img_path in data_path.glob("**/*"):
108
+ if img_path.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]:
109
+ try:
110
+ with Image.open(img_path) as img:
111
+ img.verify()
112
+
113
+ except Exception as e:
114
+ print(f"CORRUPT: {img_path} | Error: {e}")
115
+ corrupt_files.append(img_path)
116
+
117
+ if corrupt_files:
118
+ print(f"\nFound {len(corrupt_files)} corrupted files.")
119
+ else:
120
+ print("Dataset is clean")
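`CosineAnnealingWarmupLR` above ramps each learning rate linearly from zero during warmup, then follows a half cosine down to `min_lr`. The sketch below reproduces that arithmetic for a single parameter group without PyTorch, so the curve can be inspected in isolation. `lr_at_step` is a hypothetical helper, not part of the repository.

```python
import math


def lr_at_step(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at a given optimizer step for one parameter group."""
    # Linear warmup from 0 up to base_lr.
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)

    # Cosine decay from base_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    ratio = min_lr / max(base_lr, 1e-12)
    return base_lr * (ratio + (1 - ratio) * cosine)
```

For example, `[lr_at_step(s, 5e-4, 10, 100, min_lr=1e-6) for s in range(0, 101, 10)]` gives a quick view of the schedule at every tenth step.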
train.py ADDED
@@ -0,0 +1,192 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ import wandb
7
+ from omegaconf import OmegaConf
8
+ from timm.optim import create_optimizer_v2
9
+ from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR
10
+
11
+ from src.dataset import get_dataloaders
12
+ from src.loss import get_criterion
13
+ from src.models import PlantDiseaseModel, get_param_groups
14
+ from src.trainer import Trainer
15
+ from src.utils import CosineAnnealingWarmupLR, load_config, set_seed
16
+
17
+
18
+ def build_optimizer(model, config):
19
+ layer_decay = getattr(config.optimizer, "layer_decay", 1.0)
20
+ param_groups = get_param_groups(
21
+ model,
22
+ base_lr=config.optimizer.backbone_lr,
23
+ head_lr=config.optimizer.head_lr,
24
+ weight_decay=config.optimizer.weight_decay,
25
+ )
26
+
27
+ if config.optimizer.name.lower() == "adamw":
28
+ if layer_decay == 1:
29
+ optimizer = torch.optim.AdamW(param_groups)
30
+ else:
31
+ optimizer = create_optimizer_v2(
32
+ model,
33
+ opt="adamw",
34
+ lr=config.optimizer.head_lr,
35
+ layer_decay=layer_decay,
36
+ weight_decay=config.optimizer.weight_decay,
37
+ )
38
+ else:
39
+ optimizer = torch.optim.Adam(param_groups)
40
+
41
+ return optimizer
42
+
43
+
44
+ def build_scheduler(optimizer, config, len_loader):
45
+ if config.scheduler.name.lower() == "cosine":
46
+ return CosineAnnealingLR(
47
+ optimizer, T_max=config.training.epochs, eta_min=config.scheduler.min_lr
48
+ )
49
+ elif config.scheduler.name.lower() == "step":
50
+ return StepLR(optimizer, step_size=3, gamma=0.1)
51
+ elif config.scheduler.name.lower() == "plateau":
52
+ return ReduceLROnPlateau(
53
+ optimizer,
54
+ mode="max",
55
+ factor=0.1,
56
+ patience=3,
57
+ min_lr=config.scheduler.min_lr,
58
+ )
59
+ elif config.scheduler.name.lower() == "cosine_warmup":
60
+ return CosineAnnealingWarmupLR(
61
+ optimizer,
62
+ warmup_steps=config.scheduler.warmup_epochs
63
+ * len_loader
64
+ / config.training.gradient_accumulation_steps,
65
+ total_steps=config.training.epochs
66
+ * len_loader
67
+ / config.training.gradient_accumulation_steps,
68
+ min_lr=config.scheduler.min_lr,
69
+ )
70
+ else:
71
+ return None
72
+
73
+
74
+ def main():
75
+ parser = argparse.ArgumentParser(
76
+ description="Train Plant Disease Classification Baseline"
77
+ )
78
+ parser.add_argument(
79
+ "--config", type=str, default="configs/config.yaml", help="Path to config file"
80
+ )
81
+ parser.add_argument(
82
+ "--resume", type=str, default=None, help="Path to checkpoint to resume from"
83
+ )
84
+ parser.add_argument(
85
+ "--init_weights", type=str, default=None, help="Path to weights for warm start"
86
+ )
87
+ args = parser.parse_args()
88
+
89
+ config = load_config(args.config)
90
+
91
+ set_seed(config.seed, deterministic=True)
92
+
93
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+ print(f"Environment: Using device {device}")
95
+
96
+ train_loader, val_loader, num_classes = get_dataloaders(config)
97
+
98
+ if num_classes == 0:
99
+ print(
100
+ "WARNING: No data found. Make sure your datasets are correctly structured."
101
+ )
102
+ # Fallback to prevent immediate crash if no data is present yet
103
+ num_classes = 1
104
+
105
+ config.model.num_classes = num_classes
106
+
107
+ model = PlantDiseaseModel(config, num_classes=num_classes)
108
+ model.to(device)
109
+
110
+ if args.init_weights and os.path.exists(args.init_weights):
111
+ print(f"Warm starting from weights: {args.init_weights}")
112
+ checkpoint = torch.load(args.init_weights, map_location=device)
113
+ state_dict = checkpoint.get("state_dict", checkpoint)
114
+ model.load_state_dict(state_dict)
115
+
116
+ optimizer = build_optimizer(model, config)
117
+ criterion = get_criterion(config)
118
+ scheduler = build_scheduler(optimizer, config, len(train_loader))
119
+
120
+ # resume Logic
121
+ start_epoch = 1
122
+ checkpoint = None
123
+ run_id = None
124
+ if args.resume and os.path.exists(args.resume):
125
+ print(f"Resuming experiment from checkpoint: {args.resume}")
126
+ checkpoint = torch.load(args.resume, map_location=device)
127
+ model.load_state_dict(checkpoint["state_dict"])
128
+ optimizer.load_state_dict(checkpoint["optimizer"])
129
+ if scheduler and checkpoint["scheduler"]:
130
+ scheduler.load_state_dict(checkpoint["scheduler"])
131
+ start_epoch = checkpoint["epoch"] + 1
132
+
133
+ if "rng_states" in checkpoint:
134
+ torch.set_rng_state(checkpoint["rng_states"]["torch"].cpu())
135
+ if device.type == "cuda" and checkpoint["rng_states"]["cuda"] is not None:
136
+ torch.cuda.set_rng_state_all(
137
+ [s.cpu() for s in checkpoint["rng_states"]["cuda"]]
138
+ )
139
+
140
+ if config.logging.use_wandb:
141
+ run_id = checkpoint.get("wandb_run_id")
142
+
143
+ if start_epoch > config.training.epochs:
144
+ print(
145
+ f"Requested to resume at epoch {start_epoch}, but total epochs is {config.training.epochs}. Exiting."
146
+ )
147
+ return
148
+
149
+ # Wandb tracking
150
+ if config.logging.use_wandb:
151
+ wandb_config = OmegaConf.to_container(config, resolve=True)
152
+ wandb.init(
153
+ project=config.logging.project_name,
154
+ name=config.experiment_name,
155
+ config=wandb_config,
156
+ id=run_id, # Use the loaded ID (or None if brand new)
157
+ resume="allow",
158
+ )
159
+
160
+ trainer = Trainer(
161
+ model=model,
162
+ train_loader=train_loader,
163
+ val_loader=val_loader,
164
+ criterion=criterion,
165
+ optimizer=optimizer,
166
+ scheduler=scheduler,
167
+ config=config,
168
+ device=device,
169
+ )
170
+
171
+ if checkpoint is not None:
172
+ if trainer.use_ema and checkpoint.get("state_dict_ema"):
173
+ trainer.model_ema.module.load_state_dict(checkpoint["state_dict_ema"])
174
+
175
+ if args.resume and os.path.exists(args.resume):
176
+ if checkpoint["scaler"]:
177
+ trainer.scaler.load_state_dict(checkpoint["scaler"])
178
+
179
+ if checkpoint["early_stopping"]:
180
+ trainer.early_stopping.best_score = checkpoint["early_stopping"][
181
+ "best_score"
182
+ ]
183
+ trainer.early_stopping.counter = checkpoint["early_stopping"]["counter"]
184
+ trainer.early_stopping.early_stop = checkpoint["early_stopping"][
185
+ "early_stop"
186
+ ]
187
+
188
+ trainer.fit(start_epoch=start_epoch)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
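The `cosine_warmup` scheduler is stepped once per optimizer update. With gradient accumulation, `Trainer.train_one_epoch` performs one update every `gradient_accumulation_steps` batches, plus one more for a leftover partial group at the end of the epoch. A small sketch of the resulting integer step counts follows. The helper names are hypothetical, and using a ceiling here is an assumption about the intended step budget rather than something the repository states.

```python
import math


def optimizer_steps_per_epoch(num_batches, accumulation_steps):
    """Optimizer updates in one epoch, counting the final partial group."""
    return math.ceil(num_batches / accumulation_steps)


def schedule_lengths(num_batches, accumulation_steps, epochs, warmup_epochs):
    """Return (warmup_steps, total_steps) as whole optimizer steps."""
    per_epoch = optimizer_steps_per_epoch(num_batches, accumulation_steps)
    return warmup_epochs * per_epoch, epochs * per_epoch
```

With 5 batches per epoch and an accumulation factor of 2, each epoch has 3 optimizer updates, so 15 epochs with 3 warmup epochs give 9 warmup steps out of 45 total.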
wandb/run-20260419_175057-4kiikgrp/files/config.yaml ADDED
@@ -0,0 +1,115 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.25.1
4
+ e:
5
+ 5r5nghwc1gn8jkabzqczvjhk7aa94c6x:
6
+ codePath: train.py
7
+ codePathLocal: train.py
8
+ cpu_count: 2
9
+ cpu_count_logical: 4
10
+ disk:
11
+ /:
12
+ total: "254356226048"
13
+ used: "139972616192"
14
+ email: coretwoduo75@gmail.com
15
+ executable: /home/union_point/miniconda3/bin/python
16
+ host: laptop
17
+ memory:
18
+ total: "7760048128"
19
+ os: Linux-6.19.11-200.fc43.x86_64-x86_64-with-glibc2.42
20
+ program: /home/union_point/ml/Plant-Disease-Classification/train.py
21
+ python: CPython 3.13.2
22
+ root: /home/union_point/ml/Plant-Disease-Classification
23
+ startedAt: "2026-04-19T13:50:57.074183Z"
24
+ writerId: 5r5nghwc1gn8jkabzqczvjhk7aa94c6x
25
+ m: []
26
+ python_version: 3.13.2
27
+ t:
28
+ "1":
29
+ - 1
30
+ - 5
31
+ - 11
32
+ - 41
33
+ - 49
34
+ - 53
35
+ - 63
36
+ - 71
37
+ "2":
38
+ - 1
39
+ - 5
40
+ - 11
41
+ - 41
42
+ - 49
43
+ - 53
44
+ - 63
45
+ - 71
46
+ "3":
47
+ - 13
48
+ - 16
49
+ "4": 3.13.2
50
+ "5": 0.25.1
51
+ "6": 5.3.0
52
+ "12": 0.25.1
53
+ "13": linux-x86_64
54
+ augmentation:
55
+ value:
56
+ cutmix_alpha: 0.5
57
+ mixup_alpha: 0.2
58
+ prob: 0
59
+ data:
60
+ value:
61
+ batch_size: 64
62
+ image_size: 384
63
+ max_weight: 5
64
+ num_workers: 2
65
+ pin_memory: true
66
+ train_dir: data/train
67
+ val_dir: data/val
68
+ weighted_sampling: false
69
+ experiment_name:
70
+ value: dinov3_vit_small_plus_baseline
71
+ logging:
72
+ value:
73
+ checkpoint_dir: ./checkpoints
74
+ project_name: plant-disease-classification
75
+ use_wandb: true
76
+ loss:
77
+ value:
78
+ alpha: 0.25
79
+ gamma: 2
80
+ label_smoothing: 0.1
81
+ name: ce
82
+ model:
83
+ value:
84
+ backbone: vit_small_plus_patch16_dinov3.lvd1689m
85
+ drop_path: 0.1
86
+ dropout: 0.1
87
+ freeze_backbone: true
88
+ freeze_bn: false
89
+ num_classes: 39
90
+ pretrained: true
91
+ optimizer:
92
+ value:
93
+ backbone_lr: 0
94
+ head_lr: 0.0005
95
+ layer_decay: 1
96
+ name: adamw
97
+ weight_decay: 0.0001
98
+ scheduler:
99
+ value:
100
+ min_lr: 1e-06
101
+ name: cosine_warmup
102
+ warmup_epochs: 3
103
+ seed:
104
+ value: 42
105
+ training:
106
+ value:
107
+ clip_grad_norm: 1
108
+ early_stopping_patience: 5
109
+ ema:
110
+ decay: 0.999
111
+ enabled: true
112
+ eval_mode: current
113
+ epochs: 15
114
+ gradient_accumulation_steps: 2
115
+ mixed_precision: true
wandb/run-20260419_175057-4kiikgrp/files/requirements.txt ADDED
@@ -0,0 +1,317 @@
+ wcwidth==0.2.13
+ pure_eval==0.2.3
+ traitlets==5.14.3
+ tornado==6.4.2
+ pyzmq==26.2.1
+ psutil==7.0.0
+ prompt_toolkit==3.0.50
+ parso==0.8.4
+ nest-asyncio==1.6.0
+ ipython_pygments_lexers==1.1.1
+ executing==2.2.0
+ decorator==5.2.1
+ debugpy==1.8.13
+ asttokens==3.0.0
+ stack-data==0.6.3
+ matplotlib-inline==0.1.7
+ jupyter_core==5.7.2
+ jedi==0.19.2
+ comm==0.2.2
+ jupyter_client==8.6.3
+ ipython==9.0.2
+ ipykernel==6.29.5
+ mdit-py-plugins==0.4.2
+ jupytext==1.16.7
+ threadpoolctl==3.6.0
+ pydotplus==2.0.2
+ seaborn==0.13.2
+ patsy==1.0.1
+ statsmodels==0.14.4
+ category_encoders==2.8.1
+ cvxopt==1.3.2
+ mpmath==1.3.0
+ sympy==1.13.3
+ networkx==3.3
+ torchvision==0.22.0+cpu
+ graphviz==0.20.3
+ torchviz==0.0.3
+ tqdm==4.67.1
+ lightning-utilities==0.14.3
+ torchmetrics==1.7.1
+ nltk==3.9.1
+ imageio==2.37.0
+ tifffile==2025.5.10
+ lazy_loader==0.4
+ scikit-image==0.25.2
+ pypickle==1.1.5
+ datazets==1.1.2
+ colourmap==1.1.21
+ scatterd==1.3.9
+ clusteval==2.2.5
+ bcubed==1.5
+ bcubed-metrics==1.0.1
+ pytorch-lightning==2.5.1.post0
+ lightning==2.5.1.post0
+ opencv-python==4.11.0.86
+ opencv-python-headless==4.11.0.86
+ qudida==0.0.4
+ albumentations==1.1.0
+ safetensors==0.5.3
+ xxhash==3.5.0
+ fsspec==2025.3.0
+ dill==0.3.8
+ multiprocess==0.70.16
+ datasets==3.6.0
+ menuinst==2.2.0
+ anaconda-anon-usage==0.7.0
+ annotated-types==0.6.0
+ archspec==0.2.3
+ boltons==24.1.0
+ Brotli==1.0.9
+ distro==1.9.0
+ frozendict==2.4.2
+ idna==3.7
+ jsonpointer==2.1
+ libmambapy==2.0.5
+ mdurl==0.1.0
+ packaging==24.2
+ platformdirs==4.3.7
+ pluggy==1.5.0
+ pycosat==0.6.6
+ pycparser==2.21
+ Pygments==2.19.1
+ PySocks==1.7.1
+ ruamel.yaml.clib==0.2.12
+ tqdm==4.67.1
+ truststore==0.10.0
+ wheel==0.45.1
+ cffi==1.17.1
+ jsonpatch==1.33
+ markdown-it-py==2.2.0
+ pip==25.0
+ ruamel.yaml==0.18.10
+ cryptography==44.0.1
+ requests==2.32.3
+ rich==13.9.4
+ zstandard==0.23.0
+ conda-content-trust==0.2.0
+ conda_package_streaming==0.11.0
+ conda-package-handling==2.4.0
+ conda==25.3.1
+ conda-anaconda-telemetry==0.1.2
+ conda-anaconda-tos==0.1.3
+ conda-libmamba-solver==25.4.0
+ pillow==11.2.1
+ imageio==2.37.0
+ tifffile==2025.5.10
+ lazy_loader==0.4
+ scikit-image==0.25.2
+ pytz==2025.2
+ tzdata==2025.2
+ six==1.17.0
+ pypickle==1.1.5
+ pyparsing==3.2.3
+ kiwisolver==1.4.8
+ fonttools==4.58.0
+ cycler==0.12.1
+ contourpy==1.3.2
+ python-dateutil==2.9.0.post0
+ datazets==1.1.2
+ colourmap==1.1.21
+ scatterd==1.3.9
+ clusteval==2.2.5
+ ruff==0.11.11
+ PyYAML==6.0.2
+ propcache==0.3.1
+ multidict==6.4.4
+ MarkupSafe==3.0.2
+ fsspec==2025.5.1
+ frozenlist==1.6.0
+ filelock==3.18.0
+ attrs==25.3.0
+ aiohappyeyeballs==2.6.1
+ yarl==1.20.0
+ Jinja2==3.1.6
+ pytorch-lightning==2.5.1.post0
+ lightning==2.5.1.post0
+ regex==2025.7.34
+ accelerate==1.9.0
+ peft==0.17.0
+ torchaudio==2.8.0
+ standard-chunk==3.13.0
+ soxr==0.5.0.post1
+ msgpack==1.1.1
+ llvmlite==0.44.0
+ audioread==3.0.1
+ audioop-lts==0.2.2
+ standard-sunau==3.13.0
+ standard-aifc==3.13.0
+ soundfile==0.13.1
+ pooch==1.8.2
+ numba==0.61.2
+ librosa==0.11.0
+ pyarrow==21.0.0
+ antlr4-python3-runtime==4.9.3
+ omegaconf==2.3.0
+ protobuf==6.32.0
+ smmap==5.0.2
+ sentry-sdk==2.35.0
+ click==8.2.1
+ gitdb==4.0.12
+ GitPython==3.1.45
+ sniffio==1.3.1
+ python-multipart==0.0.21
+ jiter==0.12.0
+ h11==0.16.0
+ anyio==4.12.1
+ annotated-doc==0.0.4
+ uvicorn==0.40.0
+ starlette==0.50.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ fastapi==0.128.0
+ openai==2.15.0
+ PyMuPDF==1.26.7
+ ConfigArgParse==1.7.1
+ tenacity==9.1.2
+ elevenlabs==2.39.0
+ zipp==3.23.0
+ types-protobuf==6.32.1.20260221
+ shellingham==1.5.4
+ PyJWT==2.12.1
+ prometheus_client==0.24.1
+ opentelemetry-proto==1.39.1
+ livekit-blingfire==1.1.0
+ grpcio==1.78.0
+ googleapis-common-protos==1.73.0
+ eval_type_backport==0.3.1
+ docstring_parser==0.17.0
+ colorama==0.4.6
+ certifi==2026.2.25
+ av==17.0.0
+ aiofiles==25.1.0
+ watchfiles==1.1.1
+ sounddevice==0.5.5
+ opentelemetry-exporter-otlp-proto-common==1.39.1
+ livekit-protocol==1.1.3
+ importlib_metadata==8.7.1
+ typer==0.24.1
+ opentelemetry-api==1.39.1
+ livekit-api==1.1.0
+ opentelemetry-semantic-conventions==0.60b1
+ opentelemetry-sdk==1.39.1
+ opentelemetry-exporter-otlp-proto-http==1.39.1
+ opentelemetry-exporter-otlp-proto-grpc==1.39.1
+ opentelemetry-exporter-otlp==1.39.1
+ python-dotenv==1.2.2
+ flatbuffers==25.12.19
+ onnxruntime==1.24.4
+ livekit-plugins-turn-detector==1.4.6
+ livekit-plugins-silero==1.4.6
+ websockets==15.0.1
+ livekit-plugins-openai==1.4.6
+ ctranslate2==4.7.1
+ faster-whisper==1.2.1
+ typing_extensions==4.15.0
+ pyasn1==0.6.3
+ proto-plus==1.27.1
+ pyasn1_modules==0.4.2
+ grpcio-status==1.78.0
+ google-auth==2.49.1
+ google-api-core==2.30.0
+ google-genai==1.68.0
+ google-cloud-texttospeech==2.34.0
+ google-cloud-speech==2.37.0
+ livekit-plugins-google==1.4.6
+ hf-xet==1.4.2
+ huggingface_hub==1.7.1
+ tokenizers==0.22.2
+ transformers==5.3.0
+ tiktoken==0.12.0
+ whisperlivekit==0.2.20.post1
+ livekit==1.0.25
+ aiosignal==1.4.0
+ aiohttp==3.13.3
+ livekit-agents==1.5.0
+ livekit-plugins-elevenlabs==1.5.0
+ filetype==1.2.0
+ dirtyjson==1.0.8
+ wrapt==2.1.2
+ typing-inspection==0.4.2
+ tinytag==2.2.1
+ setuptools==82.0.1
+ pydantic_core==2.41.5
+ mypy_extensions==1.1.0
+ marshmallow==3.26.2
+ griffelib==2.0.0
+ greenlet==3.3.2
+ aiosqlite==0.22.1
+ typing-inspect==0.9.0
+ SQLAlchemy==2.0.48
+ pydantic==2.12.5
+ griffecli==2.0.0
+ Deprecated==1.3.1
+ llama-index-instrumentation==0.5.0
+ llama_cloud==1.6.0
+ griffe==2.0.0
+ dataclasses-json==0.6.7
+ llama-index-workflows==2.17.0
+ banks==2.4.1
+ llama-index-core==0.14.18
+ llama-parse==0.5.20
+ llama-index-llms-openai==0.7.2
+ llama-index-indices-managed-llama-cloud==0.11.0
+ llama-index-embeddings-openai==0.6.0
+ llama-index-readers-llama-parse==0.6.0
+ llama-index-cli==0.5.6
+ llama-index==0.14.18
+ soupsieve==2.8.3
+ beautifulsoup4==4.14.3
+ sortedcontainers==2.4.0
+ wsproto==1.3.2
+ websocket-client==1.9.0
+ urllib3==2.6.3
+ tzlocal==5.3.1
+ tld==0.13.2
+ outcome==1.3.0.post0
+ lxml==6.0.2
+ charset-normalizer==3.4.6
+ babel==2.18.0
+ trio==0.33.0
+ lxml_html_clean==0.4.4
+ dateparser==1.3.0
+ courlan==1.3.2
+ webdriver-manager==4.0.2
+ trio-websocket==0.12.2
+ htmldate==1.9.4
+ selenium==4.41.0
+ jusText==3.0.2
+ trafilatura==2.0.0
+ pypdf==6.9.1
+ pdfminer.six==20260107
+ nvidia-cusparselt-cu12==0.6.3
+ triton==3.3.0
+ nvidia-nvtx-cu12==12.6.77
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nccl-cu12==2.26.2
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cufile-cu12==1.11.1.6
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cusolver-cu12==11.7.1.2
+ torch==2.7.0
+ timm==1.0.26
+ wandb==0.25.1
+ xgboost==3.2.0
+ numpy==2.4.4
+ joblib==1.5.3
+ scipy==1.17.1
+ pandas==3.0.2
+ scikit-learn==1.8.0
+ matplotlib==3.10.8
+ mlxtend==0.24.0
wandb/run-20260419_175057-4kiikgrp/files/wandb-metadata.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "os": "Linux-6.19.11-200.fc43.x86_64-x86_64-with-glibc2.42",
+   "python": "CPython 3.13.2",
+   "startedAt": "2026-04-19T13:50:57.074183Z",
+   "program": "/home/union_point/ml/Plant-Disease-Classification/train.py",
+   "codePath": "train.py",
+   "codePathLocal": "train.py",
+   "email": "coretwoduo75@gmail.com",
+   "root": "/home/union_point/ml/Plant-Disease-Classification",
+   "host": "laptop",
+   "executable": "/home/union_point/miniconda3/bin/python",
+   "cpu_count": 2,
+   "cpu_count_logical": 4,
+   "disk": {
+     "/": {
+       "total": "254356226048",
+       "used": "139972616192"
+     }
+   },
+   "memory": {
+     "total": "7760048128"
+   },
+   "writerId": "5r5nghwc1gn8jkabzqczvjhk7aa94c6x"
+ }
wandb/run-20260419_175057-4kiikgrp/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
+ {"_wandb":{"runtime":12},"_runtime":12}
wandb/run-20260419_175057-4kiikgrp/run-4kiikgrp.wandb ADDED
Binary file (8.59 kB).