---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
- accuracy
- f1
- auc
tags:
- swin-transformer
- timm
- image-classification
- plant-disease
- tea-leaf
- rgb-hsv
- color-aware
datasets:
- tea-leaf-disease
language:
- en
---
# Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification 🌱🍃
This repository provides a **Swin Transformer Small** model fine-tuned for **tea leaf disease classification** using a **color-aware RGB + HSV fusion** strategy.
The model reaches **96.01% top-1 accuracy** and **99.59% macro-AUC** on the test set.

---
## 🧠 Model Overview
- **Architecture:** Swin Transformer Small (`swin_small_patch4_window7_224`)
- **Pretrained:** Yes (ImageNet)
- **Input:** RGB + HSV
- **HSV Fusion:** Raw HSV channels (no sin/cos encoding)
- **Gating:** Vector gate (disabled in this run)
- **DropPath:** 0.2
- **EMA:** Enabled
- **AMP:** Enabled
- **Framework:** PyTorch (timm-style training)
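
The overview says the HSV channels are fused raw, without sin/cos encoding. As a rough illustration of what that distinction means for a single pixel, the sketch below uses Python's standard `colorsys` module. The function names `rgb_hsv_raw` and `rgb_hsv_sincos` are hypothetical, and how this repository actually fuses the channels inside the network is not described in this card.

```python
import colorsys
import math


def rgb_hsv_raw(r, g, b):
    """Concatenate RGB with raw HSV: 6 values, each in [0, 1]."""
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    return [r, g, b, h, s, v]


def rgb_hsv_sincos(r, g, b):
    """Alternative encoding: the circular hue becomes (sin, cos), giving 7 values."""
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    angle = 2.0 * math.pi * h
    return [r, g, b, math.sin(angle), math.cos(angle), s, v]


# Two visually similar reds that sit on either side of the hue wrap-around.
red = (1.0, 0.0, 0.0)
red_magenta = (1.0, 0.0, 0.1)

raw_a, raw_b = rgb_hsv_raw(*red), rgb_hsv_raw(*red_magenta)
enc_a, enc_b = rgb_hsv_sincos(*red), rgb_hsv_sincos(*red_magenta)

# Raw hue jumps across the wrap-around even though the colors are close...
raw_hue_gap = abs(raw_a[3] - raw_b[3])
# ...while the (sin, cos) pair moves only slightly.
enc_hue_gap = math.hypot(enc_a[3] - enc_b[3], enc_a[4] - enc_b[4])
print(f"raw hue gap: {raw_hue_gap:.3f}, sin/cos hue gap: {enc_hue_gap:.3f}")
```

Raw HSV keeps the input at six channels and avoids extra encoding, at the cost of a discontinuity in hue near red. Which trade-off works better is an empirical question; this card reports results for the raw variant only.
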
---
## 📐 Model Complexity
| Metric | Value |
|------|------|
| Parameters | **49.47M** |
| GFLOPs | **17.16** |
| Weights size | **~200 MB** |
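
The weights size in the table is roughly what the parameter count implies. The short calculation below is a sanity check only, and it assumes every parameter is stored as a 4-byte float32, which this card does not state explicitly.

```python
# Parameter count from the table above.
num_params = 49.47e6

# Assumption: each parameter is stored as a 4-byte float32.
bytes_per_param = 4

size_bytes = num_params * bytes_per_param
size_mb = size_bytes / 1e6       # decimal megabytes
size_mib = size_bytes / 2**20    # binary mebibytes

print(f"{size_mb:.1f} MB ({size_mib:.1f} MiB)")
```

Either unit lands close to the stated ~200 MB.
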
---
## 📊 Final Test Performance
Evaluation performed using **EMA weights from the best checkpoint (epoch 93)**.
| Metric | Score |
|------|------|
| **Top-1 Accuracy** | **96.01%** |
| **Macro-F1** | **95.51%** |
| **Macro-AUC** | **99.59%** |
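
For readers unfamiliar with the macro-averaged metrics in the table, here is a minimal pure-Python sketch of top-1 accuracy and macro-F1. The helper names and the toy labels are illustrative assumptions, not the evaluation code or data used for this model.

```python
def top1_accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class equals the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # Convention assumed here: a class with no true or predicted samples scores 0.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


# Toy labels for illustration only (not data from this model).
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 0, 1, 1]

acc = top1_accuracy(y_true, y_pred)
f1 = macro_f1(y_true, y_pred, num_classes=3)
print(f"accuracy={acc:.3f}, macro-F1={f1:.3f}")
```

In the toy example a single missed sample from a rare class pulls macro-F1 well below accuracy, which is why the two are reported separately.
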
**Benchmark details**
- Test images: **212**
- Total inference time: **2.30s**
- Throughput: **92.3 images/sec**
---
## 🚀 Inference Speed
- **Measurement:** post-warmup, forward pass only
- **Throughput:** **92.3 img/s** on GPU
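
A post-warmup, forward-only measurement can be sketched as follows. This is not the benchmark script used for the numbers above; `measure_throughput` and its parameters are hypothetical, and the dummy `forward` stands in for a real model call.

```python
import time


def measure_throughput(forward, batches, batch_size, warmup=3, sync=None):
    """Return images/sec for forward-only calls, excluding warm-up iterations.

    `sync` is an optional callable that waits for pending device work;
    on a CUDA device one would typically pass torch.cuda.synchronize.
    """
    # Warm-up runs are executed but not timed.
    for batch in batches[:warmup]:
        forward(batch)
    if sync is not None:
        sync()

    start = time.perf_counter()
    for batch in batches:
        forward(batch)
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start

    return len(batches) * batch_size / elapsed


# Dummy workload for illustration: each "batch" is a list of 1000 numbers.
dummy_batches = [list(range(1000)) for _ in range(5)]
images_per_sec = measure_throughput(
    lambda batch: sum(x * x for x in batch),
    dummy_batches,
    batch_size=1000,
)
print(f"{images_per_sec:.1f} items/sec")
```

As a sanity check on the reported figures, 212 images in 2.30 s works out to about 92 img/s, in line with the stated 92.3 img/s once rounding of the elapsed time is taken into account.
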
---
## 🗂️ Training Details
- **Experiment name:** `swin_small_hsv_raw`
- **Device:** CUDA
- **Epochs:** 100
- **Best checkpoint:** Epoch 93
- **Gradient accumulation:** 1
- **HSV gate warmup:** 5 epochs
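
Since the evaluated weights are an exponential moving average (EMA) of the training weights, a minimal sketch of the update rule may help. The `ema_update` helper and the decay value are assumptions for illustration; the decay used in this run is not stated in the card.

```python
def ema_update(ema_weights, model_weights, decay=0.999):
    """Update EMA weights in place: ema = decay * ema + (1 - decay) * current."""
    for name, value in model_weights.items():
        ema_weights[name] = decay * ema_weights[name] + (1.0 - decay) * value
    return ema_weights


# Toy example with scalar "weights" and an exaggerated decay of 0.5.
ema = {"w": 0.0}
current = {"w": 1.0}

ema_update(ema, current, decay=0.5)   # ema["w"] is now 0.5
ema_update(ema, current, decay=0.5)   # ema["w"] is now 0.75
print(ema)
```

In a real training loop the update would run after each optimizer step over the model's tensors, and the averaged copy is the one saved for evaluation.
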
---
## 📦 Model Files
- `model.safetensors`: final EMA weights (recommended)
- Config and training artifacts included in repository
---
## 🧪 Intended Use
This model is designed for:
- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers
⚠️ **Not intended as a medical or agronomic diagnostic tool.**
---
## ⚠️ Limitations
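- Dataset-specific bias may exist, so performance on images from other sources, devices, or growing conditions may differ from the reported test metrics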
---
## 🧑‍💻 How to Use (PyTorch + timm)
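```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# Set NUM_CLASSES to the number of disease classes in the checkpoint
# (the value is not stated in this card).

# Create the model
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load weights. A .safetensors file must be read with safetensors, not torch.load.
state = load_file("model.safetensors", device="cpu")
# strict=False skips keys that do not match the plain timm model,
# so report what was skipped instead of ignoring it silently.
result = model.load_state_dict(state, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()

# Preprocessing (ImageNet mean and std)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # add a batch dimension

with torch.no_grad():
    logits = model(x)

pred = logits.argmax(dim=1).item()
print("Predicted class:", pred)
```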