
Plant Disease Classification

A robust, configurable deep learning pipeline for plant disease classification using PyTorch. The project uses timm to provide a wide range of pre-trained backbones (e.g., EfficientNetV2, ConvNeXtV2, EVA02). It also offers advanced training features: an Exponential Moving Average (EMA) of model weights, Layer-wise Learning Rate Decay (LLRD), MixUp/CutMix data augmentation, and Weights & Biases (W&B) integration for experiment tracking.

  • Web Interface:
  • REST API Documentation:

Features

  • Extensive Model Support: Easily swap backbones by changing the config, enabled by integration with timm.
  • Advanced Training Techniques:
    • Model EMA (Exponential Moving Average) to stabilize training and improve generalization.
    • Layer-wise Learning Rate Decay (LLRD) for fine-tuning transformer and CNN architectures such as ViT and ConvNeXtV2.
    • Mixed Precision Training for faster execution and lower memory footprint.
    • Gradient Accumulation.
  • Data Augmentation: MixUp and CutMix integrations for regularization.
  • Customizable Configuration: Highly modular experiment setups using omegaconf (YAML config files).
  • Experiment Tracking: Full Weights & Biases integration that logs everything from hyperparameter configs to validation metrics.
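Here is a minimal sketch of the EMA feature's update rule. It assumes the common formulation `ema = decay * ema + (1 - decay) * weight`, applied after each optimizer step. The `WeightEMA` class and its `decay` default are hypothetical and not taken from this repository. Weights are plain floats keyed by name so the example runs without PyTorch; the real pipeline would operate on tensors (timm, for instance, ships `timm.utils.ModelEmaV2`).

```python
from typing import Dict


class WeightEMA:
    """Hypothetical EMA helper: keeps a shadow copy of each weight.

    Weights are modelled as a dict of floats so the sketch runs without
    PyTorch; a tensor-based version applies the same update element-wise.
    """

    def __init__(self, weights: Dict[str, float], decay: float = 0.999) -> None:
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        # The shadow copy starts out equal to the current weights.
        self.shadow = dict(weights)

    def update(self, weights: Dict[str, float]) -> None:
        # ema <- decay * ema + (1 - decay) * current, called after each optimizer step.
        for name, value in weights.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value


ema = WeightEMA({"w": 0.0}, decay=0.9)
for _ in range(3):
    ema.update({"w": 1.0})  # the live weight jumps to 1.0 and stays there
print(round(ema.shadow["w"], 3))  # 0.271
```

Because the shadow weights lag behind the live weights, a decay close to 1.0 smooths out step-to-step noise. This is why EMA weights are usually evaluated alongside the regular ones.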

Results

Model                    mAP    Accuracy
EfficientNetV2 Small     0.87   0.815
DINOv3 ViT Small Plus    0.91   0.830
ConvNeXtV2 Tiny          0.94   0.860

Project Structure

Plant-Disease-Classification/
├── configs/
│   └── config.yaml          # Main configuration file
├── data/
│   ├── train/               # Train data (organized by class folders)
│   └── val/                 # Val data (organized by class folders)
├── src/
│   ├── dataset.py           # Dataloaders and augmentation logic
│   ├── infer.py             # Inference script and prediction utilities
│   ├── loss.py              # Loss functions (CrossEntropy, Focal Loss)
│   ├── metrics.py           # Metric calculations
│   ├── models.py            # Model definitions and param groupings
│   ├── trainer.py           # Core training loop
│   └── utils.py             # Helpers (schedulers, seeds, config loading)
├── train.py                 # Main entrypoint for training
└── requirements.txt         # Project dependencies

Quick Start

1. Environment Setup

It is highly recommended to use uv for fast, reliable package management.

# Create a virtual environment using uv
uv venv

# Activate the environment
source .venv/bin/activate  # Linux/macOS

# Install dependencies rapidly
uv pip install -r requirements.txt

2. Prepare Data

Arrange your dataset in the torchvision ImageFolder format. Place the training data in data/train and the validation data in data/val. Each image should sit in the folder for its disease (or healthy) category.

data/
└── train/
    ├── Apple scab/
    └── ...
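Before training, it can help to confirm that both splits follow this layout and share the same classes. The sketch mirrors how torchvision's ImageFolder derives labels: sorted sub-folder names are mapped to consecutive indices. The helper names `scan_image_folder` and `check_splits` and the extension list are assumptions for illustration, not part of this project.

```python
import os
from pathlib import Path
from typing import Dict, Tuple

# Assumed extension list; torchvision's ImageFolder accepts a similar set.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def scan_image_folder(root: str) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Hypothetical helper: return (class_to_idx, images_per_class) for one split."""
    # As in ImageFolder, the classes are the sorted names of the sub-directories.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"No class folders found in {root}")
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    counts = {
        name: sum(
            1
            for p in Path(root, name).rglob("*")
            if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
        )
        for name in classes
    }
    return class_to_idx, counts


def check_splits(train_root: str, val_root: str) -> Dict[str, int]:
    """Hypothetical helper: fail early if train and val class folders differ."""
    train_map, _ = scan_image_folder(train_root)
    val_map, _ = scan_image_folder(val_root)
    if train_map != val_map:
        differing = set(train_map) ^ set(val_map)
        raise ValueError(f"Class folders differ between splits: {sorted(differing)}")
    return train_map
```

For example, `check_splits("data/train", "data/val")` returns the shared class-to-index mapping. A class folder present in only one split is therefore caught before a long training run starts.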

3. Edit the Configuration

Modify the hyperparameters, model choices, and paths inside configs/config.yaml.

4. Train the Model

Run the training pipeline:

python train.py --config configs/config.yaml

Resuming Training: To resume from an existing checkpoint, pass the --resume argument:

python train.py --config configs/config.yaml --resume checkpoints/checkpoint.pth

To load weights for a warm start (e.g., fine-tuning), use:

python train.py --config configs/config.yaml --init_weights weights/pretrained.pth
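The two options can be pictured as a command-line interface. This is a hypothetical reconstruction of train.py's argument parsing with argparse. Only the flag names come from the commands above. Treating --resume and --init_weights as mutually exclusive, and the descriptions of what each flag restores, are assumptions.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical sketch of train.py's CLI; the real parser may differ."""
    parser = argparse.ArgumentParser(description="Train a plant disease classifier.")
    parser.add_argument("--config", required=True, help="Path to the YAML config file.")
    # Assumption: resuming a run and warm-starting from weights are alternatives.
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--resume",
        default=None,
        help="Checkpoint to resume from (assumed: model, optimizer and epoch state).",
    )
    group.add_argument(
        "--init_weights",
        default=None,
        help="Weights to initialize the model from (assumed: training restarts at epoch 0).",
    )
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argv=None makes argparse read sys.argv[1:], as a real entrypoint would.
    return build_parser().parse_args(argv)
```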

5. Inference

You can run inference on a single image using the src/infer.py script. The script requires a serialized TorchScript model checkpoint.

# Basic inference
python src/infer.py --image_path path/to/leaf.jpg --checkpoint checkpoints/best_model.pt --image_size 384

# Inference with Test Time Augmentation (TTA)
python src/infer.py --image_path path/to/leaf.jpg --checkpoint checkpoints/best_model.pt --image_size 384 --tta
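The exact augmentations behind --tta are not documented here. This sketch assumes a common approach: average the softmax probabilities of the original image and its horizontal flip. `tta_predict`, `hflip`, and the toy list-of-rows image type are hypothetical stand-ins for the tensor operations in src/infer.py.

```python
import math
from typing import Callable, List, Sequence

Image = List[List[float]]  # hypothetical toy image: rows of pixel values


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hflip(image: Image) -> Image:
    # Mirror each row left-to-right (a horizontal flip).
    return [list(reversed(row)) for row in image]


def tta_predict(
    model: Callable[[Image], Sequence[float]],
    image: Image,
    views: Sequence[Callable[[Image], Image]] = (lambda x: x, hflip),
) -> List[float]:
    """Average class probabilities over several augmented views of one image."""
    probs = [softmax(model(view(image))) for view in views]
    return [sum(p[i] for p in probs) / len(probs) for i in range(len(probs[0]))]
```

Averaging probabilities rather than raw logits keeps each view's contribution on the same scale. The result remains a valid distribution.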

Note: The inference script expects a data/label_map.json file to map class indices to disease names.
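The structure of data/label_map.json is not specified. This sketch assumes a JSON object whose keys are stringified class indices (for example {"0": "Apple scab"}). It shows how predicted probabilities could be turned into named top-k results. `load_label_map` and `top_k` are hypothetical helpers.

```python
import json
from typing import Dict, List, Sequence, Tuple


def load_label_map(path: str) -> Dict[int, str]:
    """Hypothetical loader, assuming JSON of the form {"0": "Apple scab", ...}."""
    with open(path, encoding="utf-8") as f:
        raw = json.load(f)
    # JSON object keys are always strings, so convert them back to int indices.
    return {int(idx): name for idx, name in raw.items()}


def top_k(
    probs: Sequence[float], label_map: Dict[int, str], k: int = 3
) -> List[Tuple[str, float]]:
    """Return the k most likely (class name, probability) pairs, best first."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(label_map[i], probs[i]) for i in ranked]
```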

Documentation

Model Selection

By default, the pipeline uses timm.create_model(...). You can specify any model architecture available in timm (e.g. convnextv2_base, efficientnet_b0, eva02_base_patch14_448) directly in the config.yaml file under model.backbone.

Configuration Details

Configuration is managed with OmegaConf, so hyperparameters such as the loss, optimizer, and augmentation settings can be adjusted in the YAML file. For example, to enable layer-wise learning rate decay, set optimizer.layer_decay to a value below 1.0.
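LLRD typically scales each layer's learning rate by layer_decay raised to that layer's distance from the classification head. The head trains at the base rate, and earlier layers train progressively slower. This sketch follows the BEiT-style convention for ViT parameter names (patch_embed, blocks.N, head). The function names and the name-to-depth mapping are assumptions. The grouping in src/models.py may differ, especially for CNNs such as ConvNeXtV2.

```python
from typing import Dict, Iterable


def layer_id_for_vit(param_name: str, num_layers: int) -> int:
    """Hypothetical mapping from a ViT-style parameter name to a depth index."""
    if param_name.startswith(("cls_token", "pos_embed", "patch_embed")):
        return 0  # embeddings: the shallowest layer
    if param_name.startswith("blocks."):
        return int(param_name.split(".")[1]) + 1  # transformer block i -> depth i + 1
    return num_layers + 1  # final norm / classifier head: the deepest layer


def llrd_learning_rates(
    param_names: Iterable[str], base_lr: float, layer_decay: float, num_layers: int
) -> Dict[str, float]:
    """Scale base_lr by layer_decay ** (distance of each parameter's layer from the head)."""
    return {
        name: base_lr * layer_decay ** (num_layers + 1 - layer_id_for_vit(name, num_layers))
        for name in param_names
    }
```

With layer_decay = 1.0, every parameter receives base_lr, which is why only values below 1.0 change anything. In PyTorch, the resulting rates would usually be passed as optimizer parameter groups, e.g. torch.optim.AdamW([{"params": params, "lr": lr}, ...]).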

Logging & Checkpoints

  • Checkpoints are saved under the checkpoints/ directory (customizable via logging.checkpoint_dir).
  • Best model checkpoints (current and EMA) are tracked based on the monitored validation metric.
  • When logging.use_wandb is true, the script initializes a Weights & Biases run and logs training/validation losses along with the selected metrics.
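Tracking the best checkpoint for both the current and the EMA weights amounts to keeping one running best per variant. In this sketch, the `BestTracker` class, the variant names "model" and "ema", and the max/min mode switch are all hypothetical. The actual trainer may store this state differently.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class BestTracker:
    """Hypothetical tracker of the best monitored metric per model variant."""

    mode: str = "max"  # "max" for accuracy or mAP, "min" for validation loss
    best: Dict[str, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        if self.mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")

    def is_improvement(self, current: float, previous: float) -> bool:
        return current > previous if self.mode == "max" else current < previous

    def update(self, metrics: Dict[str, float]) -> List[str]:
        """Return the variants (e.g. "model", "ema") whose checkpoint should be saved."""
        improved = []
        for variant, value in metrics.items():
            if variant not in self.best or self.is_improvement(value, self.best[variant]):
                self.best[variant] = value
                improved.append(variant)
        return improved
```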

Model Weights


The trained weights are hosted on Hugging Face.

Technical Report

A comprehensive technical report on the results is included in the repository.

View Technical Report (PDF)
