---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
  - accuracy
  - f1
  - auc
tags:
  - swin-transformer
  - timm
  - image-classification
  - plant-disease
  - tea-leaf
  - rgb-hsv
  - color-aware
datasets:
  - tea-leaf-disease
language:
  - en
---

Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification 🌱🍃

This repository provides a Swin Transformer Small model fine-tuned for tea leaf disease classification using a color-aware RGB + HSV fusion strategy.
On the test set it reaches 96.01% top-1 accuracy, 95.51% macro-F1, and 99.59% macro-AUC.


🧠 Model Overview

  • Architecture: Swin Transformer Small (swin_small_patch4_window7_224)
  • Pretrained: Yes (ImageNet)
  • Input: RGB + HSV
  • HSV Fusion: Raw HSV channels (no sin/cos encoding)
  • Gating: Vector gate (disabled in this run)
  • DropPath: 0.2
  • EMA: Enabled
  • AMP: Enabled
  • Framework: PyTorch (timm-style training)
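
To make the "raw HSV" input concrete, here is a minimal per-pixel sketch using only Python's standard `colorsys` module. It assumes the fusion simply appends the three HSV channels to the RGB channels; the function names are hypothetical, and the actual fusion code in this repository may differ. The sin/cos variant is included only for contrast, since the overview says it was not used.

```python
import colorsys
import math


def rgb_hsv_raw(r, g, b):
    """Return (r, g, b, h, s, v) for one pixel, all values in [0, 1]."""
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    return (r, g, b, h, s, v)


def rgb_hsv_sincos(r, g, b):
    """Contrast case (not used by this model): hue encoded as (sin, cos)."""
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    angle = 2.0 * math.pi * h
    return (r, g, b, math.sin(angle), math.cos(angle), s, v)


# A pure green pixel: hue is one third of the way around the color wheel.
features = rgb_hsv_raw(0.0, 1.0, 0.0)
print(features)
```

With raw hue, red sits at both ends of the [0, 1] range, which is the wrap-around that a sin/cos encoding is normally meant to remove.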

📐 Model Complexity

| Metric | Value |
| --- | --- |
| Parameters | 49.47M |
| GFLOPs | 17.16 |
| Weights size | ~200 MB |
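
As a rough consistency check on the figures above: 49.47M parameters stored as 32-bit floats come to about 198 MB, in line with the ~200 MB listed. The float32 storage assumption is not stated in this README; it is only the most common default.

```python
PARAMS = 49.47e6          # parameter count from the table above
BYTES_PER_PARAM = 4       # assumption: weights stored as float32

size_mb = PARAMS * BYTES_PER_PARAM / 1e6   # decimal megabytes
print(f"{size_mb:.1f} MB")
```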

📊 Final Test Performance

Evaluation performed using EMA weights from the best checkpoint (epoch 93).

| Metric | Score |
| --- | --- |
| Top-1 Accuracy | 96.01% |
| Macro-F1 | 95.51% |
| Macro-AUC | 99.59% |
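
Macro-F1 averages the per-class F1 scores with equal weight, so rare classes count as much as common ones. The sketch below computes it from label lists in plain Python; the toy labels are illustrative only and are not drawn from this model's test set.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all classes seen in either list."""
    classes = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Toy example with three classes (illustrative labels, not real results).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
score = macro_f1(y_true, y_pred)
print(round(score, 4))
```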

Benchmark details

  • Test images: 212
  • Total inference time: 2.30s
  • Throughput: 92.3 images/sec

🚀 Inference Speed

  • Forward pass only, measured after warmup
  • 92.3 img/s on GPU (CUDA)
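
A post-warmup, forward-only figure is typically obtained by running a few untimed batches first and then timing only the model call. The following is a generic sketch of that pattern, not the script used for the numbers above; `measure_throughput` and its parameters are hypothetical names, and `forward` stands in for the model.

```python
import time


def measure_throughput(forward, batches, warmup=3):
    """Return images/sec for `forward`, excluding the first `warmup` batches."""
    for batch in batches[:warmup]:
        forward(batch)                      # untimed warmup calls

    n_images = 0
    start = time.perf_counter()
    for batch in batches[warmup:]:
        forward(batch)                      # timed forward passes only
        n_images += len(batch)
    # On a GPU, synchronize the device here (e.g. torch.cuda.synchronize())
    # before reading the clock, otherwise queued work is not counted.
    elapsed = time.perf_counter() - start
    return n_images / elapsed


# Stand-in "model": any callable that takes a batch works here.
dummy_batches = [list(range(1000)) for _ in range(10)]
rate = measure_throughput(lambda batch: sum(x * x for x in batch), dummy_batches)
print(f"{rate:.1f} img/s")
```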

🗂️ Training Details

  • Experiment name: swin_small_hsv_raw
  • Device: CUDA
  • Epochs: 100
  • Best checkpoint: Epoch 93
  • Gradient accumulation: 1
  • HSV gate warmup: 5 epochs

📦 Model Files

  • model.safetensors: final EMA weights (recommended)
  • Config and training artifacts included in repository
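
For readers unfamiliar with EMA weights: an exponential moving average keeps a smoothed copy of the parameters that is updated after each optimizer step and used for evaluation. Below is a framework-free sketch of the update rule; the decay value is an arbitrary illustration, since this README does not state the one used in training.

```python
def ema_update(ema_params, params, decay=0.999):
    """Blend the current params into the running average, in place.

    ema <- decay * ema + (1 - decay) * param, applied per parameter.
    """
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


# Toy example: one scalar "weight" whose trained value jumps from 0.0 to 1.0.
ema = {"w": 0.0}
for _ in range(3):
    ema_update(ema, {"w": 1.0}, decay=0.5)
print(ema["w"])  # prints 0.875
```

timm ships a utility of this kind (`timm.utils.ModelEmaV2`); whether this training run used it is not stated here.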

🧪 Intended Use

This model is designed for:

  • Tea leaf disease classification
  • Agricultural decision-support systems
  • Research on color-aware vision transformers

⚠️ Not intended as a medical or agronomic diagnostic tool.


⚠️ Limitations

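  • Dataset-specific bias may exist: only one dataset (tea-leaf-disease) is listed, and performance on images from other sources is not reported
  • The benchmark covers 212 test images, so the reported metrics carry some sampling uncertainty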

🧑‍💻 How to Use (PyTorch + timm)

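```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# Load weights. The checkpoint is a safetensors file, which torch.load
# cannot read; use the safetensors loader instead.
state = load_file("model.safetensors", device="cpu")

# The class count is not listed in this README, so read it from the
# classifier weight. Assumption: that key ends in "head.fc.weight"
# (recent timm) or "head.weight" (older timm).
head_key = next(k for k in state if k.endswith(("head.fc.weight", "head.weight")))
NUM_CLASSES = state[head_key].shape[0]

# Create model. Note: this is the stock RGB-only Swin from timm; it has no
# HSV branch, so any HSV-specific weights in the checkpoint are skipped below.
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# strict=False ignores non-matching keys, so report them rather than
# letting a partial load pass silently.
missing, unexpected = model.load_state_dict(state, strict=False)
print(f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
model.eval()

# Preprocessing (ImageNet mean and std)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)
pred = logits.argmax(dim=1).item()

print("Predicted class:", pred)
```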
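
The example above reports only the winning class index. If a confidence score is also wanted, the logits can be passed through a softmax; a dependency-free sketch follows (with PyTorch tensors, `torch.softmax(logits, dim=1)` does the same thing). The logit values here are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)                          # subtract the max to avoid overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 0.5, -1.0]                    # illustrative values, not model output
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)
print(pred, round(probs[pred], 3))
```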