Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification
This repository provides a Swin Transformer Small model fine-tuned for tea leaf disease classification using a color-aware RGB + HSV fusion strategy.
On the 212-image test set it reaches 96.01% top-1 accuracy, 95.51% macro-F1, and 99.59% macro-AUC.
Model Overview
- Architecture: Swin Transformer Small (`swin_small_patch4_window7_224`)
- Pretrained: Yes (ImageNet)
- Input: RGB + HSV
- HSV Fusion: Raw HSV channels (no sin/cos encoding)
- Gating: Vector gate (disabled in this run)
- DropPath: 0.2
- EMA: Enabled
- AMP: Enabled
- Framework: PyTorch (timm-style training)
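The card does not document the fusion module itself. As a minimal sketch, assuming the HSV channels are computed on the fly from an RGB tensor scaled to [0, 1] and simply stacked with it (a hypothetical early-fusion scheme, not necessarily the one used in training), the code below shows what "raw HSV" means here: hue stays a single value in [0, 1) instead of being encoded as (sin 2πh, cos 2πh). The names `rgb_to_hsv` and `fuse_rgb_hsv` are illustrative.

```python
import torch


def rgb_to_hsv(rgb: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Convert (B, 3, H, W) RGB in [0, 1] to raw HSV, all channels in [0, 1].

    Uses the same branch order as Python's colorsys.rgb_to_hsv.
    """
    r, g, b = rgb.unbind(dim=1)
    maxc = rgb.amax(dim=1)
    minc = rgb.amin(dim=1)
    delta = maxc - minc
    zeros = torch.zeros_like(maxc)

    v = maxc
    s = torch.where(delta > 0, delta / maxc.clamp(min=eps), zeros)

    # Distance of each channel from the max, guarded against division by zero
    safe_delta = delta.clamp(min=eps)
    rc = (maxc - r) / safe_delta
    gc = (maxc - g) / safe_delta
    bc = (maxc - b) / safe_delta

    h = torch.where(
        r == maxc,
        bc - gc,
        torch.where(g == maxc, 2.0 + rc - bc, 4.0 + gc - rc),
    )
    h = torch.remainder(h / 6.0, 1.0)  # raw hue in [0, 1), no sin/cos encoding
    h = torch.where(delta > 0, h, zeros)  # hue is undefined for grays; use 0

    return torch.stack([h, s, v], dim=1)


def fuse_rgb_hsv(rgb: torch.Tensor) -> torch.Tensor:
    """Hypothetical early fusion: concatenate RGB and raw HSV into 6 channels."""
    return torch.cat([rgb, rgb_to_hsv(rgb)], dim=1)


# Four pixels (pure red, green, blue, mid gray) laid out as (1, 3, 1, 4)
pixels = torch.tensor([[1.0, 0.0, 0.0, 0.5],
                       [0.0, 1.0, 0.0, 0.5],
                       [0.0, 0.0, 1.0, 0.5]]).reshape(1, 3, 1, 4)
hsv = rgb_to_hsv(pixels)
fused = fuse_rgb_hsv(pixels)
print(hsv[0, 0, 0])  # hues of the four pixels
print(fused.shape)
```

A 6-channel input like this would need a patch-embedding layer that accepts 6 channels; how the actual model handles that is not described in this card.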
Model Complexity
| Metric | Value |
|---|---|
| Parameters | 49.47M |
| GFLOPs | 17.16 |
| Weights size | ~200 MB |
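The ~200 MB weights size is consistent with storing the 49.47M parameters as 32-bit floats (4 bytes each): 49.47 × 10^6 × 4 B ≈ 197.9 MB. A small sketch of that arithmetic follows; `count_parameters` and `fp32_size_mb` are illustrative helpers, and the real file also carries tensor names and header metadata, so its size can differ slightly.

```python
from torch import nn


def count_parameters(model: nn.Module) -> int:
    """Total number of parameters (trainable or not)."""
    return sum(p.numel() for p in model.parameters())


def fp32_size_mb(num_params: int) -> float:
    """Approximate size in megabytes (10^6 bytes) at 4 bytes per float32 value."""
    return num_params * 4 / 1e6


print(f"{fp32_size_mb(49_470_000):.1f} MB")  # prints 197.9 MB

# On any model: a 768 -> 5 linear head has 768 * 5 + 5 = 3845 parameters
print(count_parameters(nn.Linear(768, 5)))  # prints 3845
```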
Final Test Performance
Evaluation performed using EMA weights from the best checkpoint (epoch 93).
| Metric | Score |
|---|---|
| Top-1 Accuracy | 96.01% |
| Macro-F1 | 95.51% |
| Macro-AUC | 99.59% |
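The three scores above can be computed from model outputs with scikit-learn. This is a sketch under assumptions: the card does not say how macro-AUC was averaged, so one-vs-rest averaging over classes is assumed here, and `summarize_metrics`, `y_true`, and `y_prob` are hypothetical names for the helper, the integer labels, and the softmax probabilities.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def summarize_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> dict:
    """Top-1 accuracy, macro-F1 and one-vs-rest macro-AUC.

    y_true: (N,) integer labels; y_prob: (N, C) class probabilities (rows sum to 1).
    """
    y_pred = y_prob.argmax(axis=1)
    return {
        "top1_accuracy": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, average="macro"),
        "macro_auc": roc_auc_score(y_true, y_prob, multi_class="ovr", average="macro"),
    }


# Toy example with 3 classes; sample 3 (true class 2) is misclassified as class 1
y_true = np.array([0, 1, 2, 2, 1, 0])
y_prob = np.array([
    [0.8, 0.1, 0.1],
    [0.1, 0.7, 0.2],
    [0.1, 0.2, 0.7],
    [0.2, 0.5, 0.3],
    [0.2, 0.6, 0.2],
    [0.6, 0.3, 0.1],
])
print(summarize_metrics(y_true, y_prob))
```

Macro averaging weights every class equally, which is why macro-F1 can sit below top-1 accuracy when the errors are concentrated in one class.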
Benchmark details
- Test images: 212
- Total inference time: 2.30 s
- Throughput: 92.3 images/sec
Inference Speed
- Measured after warmup, timing the forward pass only
- 92.3 images/sec on GPU
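A post-warmup, forward-only throughput measurement can be sketched as follows. The warmup and iteration counts, the batch size, and the name `measure_throughput` are assumptions; the card does not state the batch size or the GPU model. The `torch.cuda.synchronize()` calls matter because CUDA kernels run asynchronously, so stopping the clock without them would under-count the elapsed time.

```python
import time

import torch
from torch import nn


@torch.no_grad()
def measure_throughput(model: nn.Module, batch: torch.Tensor,
                       warmup: int = 10, iters: int = 50) -> float:
    """Images per second over forward passes only, excluding warmup iterations."""
    model.eval()
    use_cuda = batch.is_cuda

    for _ in range(warmup):  # absorb one-time costs (lazy init, kernel autotuning)
        model(batch)
    if use_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        model(batch)
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before reading the clock
    elapsed = time.perf_counter() - start

    return iters * batch.shape[0] / elapsed


# Usage on a tiny CPU model; for the real model, move it and the batch to "cuda"
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4))
ips = measure_throughput(toy, torch.randn(4, 3, 32, 32), warmup=2, iters=5)
print(f"{ips:.1f} images/sec")
```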
Training Details
- Experiment name: `swin_small_hsv_raw`
- Device: CUDA
- Epochs: 100
- Best checkpoint: Epoch 93
- Gradient accumulation: 1
- HSV gate warmup: 5 epochs
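EMA is enabled during training, and the reported test scores use the EMA weights. The sketch below shows a parameter EMA in plain PyTorch, similar in spirit to timm's `ModelEmaV2`; the class name `ModelEMA` and the default decay of 0.9998 are hypothetical placeholders, since the card does not report the decay used.

```python
import copy

import torch
from torch import nn


class ModelEMA:
    """Shadow copy of a model whose weights track an exponential moving average."""

    def __init__(self, model: nn.Module, decay: float = 0.9998):  # decay is a placeholder
        self.decay = decay
        self.module = copy.deepcopy(model).eval()
        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        """ema = decay * ema + (1 - decay) * current; call after each optimizer step."""
        for ema_v, v in zip(self.module.state_dict().values(),
                            model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:  # integer buffers (e.g. counters) are copied, not averaged
                ema_v.copy_(v)


# Demo with decay 0.5: weights start at 0, the live model jumps to 1
net = nn.Linear(2, 1)
with torch.no_grad():
    net.weight.zero_()
    net.bias.zero_()
ema = ModelEMA(net, decay=0.5)
with torch.no_grad():
    net.weight.fill_(1.0)
    net.bias.fill_(1.0)
ema.update(net)  # 0.5 * 0.0 + 0.5 * 1.0 = 0.5
ema.update(net)  # 0.5 * 0.5 + 0.5 * 1.0 = 0.75
print(ema.module.weight)
```

At evaluation time, `ema.module` is the model whose weights would be saved and scored.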
Model Files
- `model.safetensors`: final EMA weights (recommended)
- Config and training artifacts are included in the repository
Intended Use
This model is designed for:
- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers
Not intended as a medical or agronomic diagnostic tool.
Limitations
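- Dataset-specific bias may exist: all reported metrics come from a single 212-image test set, so performance on leaves photographed under other conditions (cultivars, lighting, cameras) may differ.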
How to Use (PyTorch + timm)
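```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# Replace with the number of disease classes the checkpoint was trained on
# (see the config included in this repository).
NUM_CLASSES = ...

# Create the Swin-S backbone with a classifier head of the matching size
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load weights: .safetensors files are read with safetensors, not torch.load
state = load_file("model.safetensors", device="cpu")
result = model.load_state_dict(state, strict=False)
# strict=False silently skips non-matching keys. If the checkpoint stores
# HSV-fusion modules, they appear as unexpected keys and are not used by
# this plain Swin model; check that backbone and head keys are not missing.
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()

# Preprocessing: resize, center-crop to 224x224, ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # shape (1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted class:", pred)
```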
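To report confidence alongside the predicted index, the logits can be converted to probabilities with a softmax and the top-k classes read off. This is a small sketch; `top_k_predictions` is a hypothetical helper, and mapping indices to disease names requires the label list from the training config, which is not reproduced here.

```python
import torch


def top_k_predictions(logits: torch.Tensor, k: int = 3) -> list:
    """(class_index, probability) pairs for the k most likely classes of one image."""
    probs = torch.softmax(logits, dim=1)[0]  # first (only) image in the batch
    values, indices = probs.topk(min(k, probs.numel()))
    return [(int(i), float(p)) for i, p in zip(indices, values)]


example_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
print(top_k_predictions(example_logits, k=2))  # class 0 first, then class 2
```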
Model tree for saifullah03/tea_leaf_disease
- Base model: microsoft/swin-small-patch4-window7-224