Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification 🌱🍃

This repository provides a Swin Transformer Small model fine-tuned for tea leaf disease classification using a color-aware RGB + HSV fusion strategy.
On the test set, the model reaches 96.01% top-1 accuracy, 95.51% macro-F1, and 99.59% macro-AUC.


🧠 Model Overview

  • Architecture: Swin Transformer Small (swin_small_patch4_window7_224)
  • Pretrained: Yes (ImageNet)
  • Input: RGB + HSV
  • HSV Fusion: Raw HSV channels (no sin/cos encoding)
  • Gating: Vector gate (disabled in this run)
  • DropPath: 0.2
  • EMA: Enabled
  • AMP: Enabled
  • Framework: PyTorch (timm-style training)
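
The Input and HSV Fusion entries above can be illustrated with a minimal sketch. It assumes that "raw HSV" means the three HSV channels are stacked next to RGB as extra input features; the actual fusion layer in this repository is not described here, so the helper names `rgb_hsv_pixel` and `hue_sin_cos` and the six-value layout are hypothetical. It uses only the standard-library `colorsys` module, whereas a real pipeline would vectorize the conversion over tensors.

```python
import colorsys
import math


def rgb_hsv_pixel(r, g, b):
    """Return [R, G, B, H, S, V] for one pixel, all values in [0, 1].

    Hypothetical helper: illustrates raw-HSV stacking, not the repository's code.
    """
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    return [r, g, b, h, s, v]


def hue_sin_cos(h):
    """Circular hue encoding, the alternative this model does not use."""
    angle = 2.0 * math.pi * h
    return math.sin(angle), math.cos(angle)


# A fully saturated green pixel, used here as a toy input
features = rgb_hsv_pixel(0.0, 1.0, 0.0)
print(features)
```

Raw hue is discontinuous at the red boundary (values near 0 and near 1 are the same color), which is what a sin/cos encoding would smooth out. This model keeps the raw channels.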

📐 Model Complexity

| Metric | Value |
|---|---|
| Parameters | 49.47M |
| GFLOPs | 17.16 |
| Weights size | ~200 MB |
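
The weights size follows from the parameter count. A quick check, assuming the weights are stored as 32-bit floats (the storage dtype is not stated here):

```python
PARAMS = 49.47e6       # parameter count from the figures above
BYTES_PER_PARAM = 4    # assumption: float32 storage

size_mb = PARAMS * BYTES_PER_PARAM / 1e6
print(f"{size_mb:.1f} MB")
```

This gives roughly 198 MB, consistent with the ~200 MB file size.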

📊 Final Test Performance

Evaluation performed using EMA weights from the best checkpoint (epoch 93).

| Metric | Score |
|---|---|
| Top-1 Accuracy | 96.01% |
| Macro-F1 | 95.51% |
| Macro-AUC | 99.59% |
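
For readers less familiar with the metrics above, here is a minimal pure-Python sketch of top-1 accuracy and macro-F1. The labels are toy values, not data from this model, and the evaluation code actually used for this repository may differ.

```python
def top1_accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class equals the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    classes = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Toy labels for illustration only
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]
print(top1_accuracy(y_true, y_pred), macro_f1(y_true, y_pred))
```

Because macro-F1 weights every class equally, it can differ from accuracy when the classes are imbalanced.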

Benchmark details

  • Test images: 212
  • Total inference time: 2.30s
  • Throughput: 92.3 images/sec

🚀 Inference Speed

  • Measured after warmup, forward pass only
  • 92.3 images/sec on GPU
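
A post-warmup, forward-only measurement can be sketched as follows. How the figure above was actually measured is not specified, so `measure_throughput` and its `warmup` parameter are hypothetical, and the dummy `forward` stands in for a real model call.

```python
import time


def measure_throughput(forward, batches, warmup=3):
    """Return images/sec for forward() over batches, skipping warmup batches.

    Hypothetical helper. On a GPU, forward() should also synchronize the
    device (e.g. torch.cuda.synchronize()) so that queued work is counted.
    """
    for batch in batches[:warmup]:
        forward(batch)                  # warmup passes are not timed
    timed = batches[warmup:]
    start = time.perf_counter()
    for batch in timed:
        forward(batch)
    elapsed = time.perf_counter() - start
    n_images = sum(len(b) for b in timed)
    return n_images / elapsed


def dummy_forward(batch):
    """Stand-in for a model forward pass."""
    time.sleep(0.001)
    return len(batch)


batches = [[0] * 8 for _ in range(10)]
rate = measure_throughput(dummy_forward, batches)
print(f"{rate:.1f} images/sec")
```

Excluding the first few batches keeps one-time costs such as kernel selection and memory allocation out of the reported rate.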

🗂️ Training Details

  • Experiment name: swin_small_hsv_raw
  • Device: CUDA
  • Epochs: 100
  • Best checkpoint: Epoch 93
  • Gradient accumulation: 1
  • HSV gate warmup: 5 epochs
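
The evaluation uses EMA weights. As a minimal sketch of the idea, an exponential moving average keeps a smoothed copy of the weights that is updated after each optimizer step. The decay value below is an assumption chosen for illustration; the value used in training is not stated here, and `ema_update` is a hypothetical helper operating on plain lists rather than tensors.

```python
def ema_update(ema_weights, model_weights, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * weight."""
    return [decay * e + (1.0 - decay) * w
            for e, w in zip(ema_weights, model_weights)]


# Toy weights for illustration only
ema = [0.0, 1.0]
for step_weights in ([1.0, 1.0], [1.0, 1.0]):
    ema = ema_update(ema, step_weights, decay=0.9)
print(ema)
```

The averaged copy changes more slowly than the raw weights, which tends to smooth out step-to-step noise at evaluation time.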

📦 Model Files

  • model.safetensors: final EMA weights (recommended)
  • Config and training artifacts included in repository

🧪 Intended Use

This model is designed for:

  • Tea leaf disease classification
  • Agricultural decision-support systems
  • Research on color-aware vision transformers

⚠️ Not intended as a medical or agronomic diagnostic tool.


⚠️ Limitations

  • Dataset-specific bias may exist; results on images from other sources may differ from the reported test metrics

🧑‍💻 How to Use (PyTorch + timm)

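import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# NUM_CLASSES must be set to the number of disease classes in the checkpoint
# (see the repository config) before running this snippet.

# Create model
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES
)

# Load weights (a .safetensors file needs load_file, not torch.load)
state = load_file("model.safetensors", device="cpu")
# strict=False skips keys that do not match this plain timm model,
# so report them instead of ignoring them silently
result = model.load_state_dict(state, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()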

# Preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)

with torch.no_grad():
    logits = model(x)
pred = logits.argmax(dim=1).item()

print("Predicted class:", pred)