---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
- accuracy
- f1
- auc
tags:
- swin-transformer
- timm
- image-classification
- plant-disease
- tea-leaf
- rgb-hsv
- color-aware
datasets:
- tea-leaf-disease
language:
- en
---

# Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification 🌱

This repository provides a **Swin Transformer Small** model fine-tuned for **tea leaf disease classification** using a **color-aware RGB + HSV fusion** strategy.
On the held-out test set (212 images), it reaches **96.01% top-1 accuracy**, **95.51% macro-F1**, and **99.59% macro-AUC**.

---

## 🧠 Model Overview

- **Architecture:** Swin Transformer Small (`swin_small_patch4_window7_224`)
- **Pretrained:** Yes (ImageNet)
- **Input:** RGB + HSV
- **HSV Fusion:** Raw HSV channels (no sin/cos encoding)
- **Gating:** Vector gate (disabled in this run)
- **DropPath:** 0.2
- **EMA:** Enabled
- **AMP:** Enabled
- **Framework:** PyTorch (timm-style training)

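
Hue is an angle stored on a linear scale, so raw HSV wraps around: two nearly identical reds can sit at opposite ends of the hue range. The dependency-free sketch below (standard-library `colorsys`, hue in [0, 1)) contrasts the raw encoding used in this run with the sin/cos alternative for two example pixels. The helper names and pixel values are illustrative assumptions, not the repository's fusion code. The trade-off it shows: raw HSV keeps three channels but has a discontinuity at red, while sin/cos spends a fourth channel to remove it.

```python
from __future__ import annotations

import colorsys
import math


def raw_hsv(r: float, g: float, b: float) -> tuple[float, float, float]:
    """Raw HSV for an RGB pixel in [0, 1]; hue is returned in [0, 1)."""
    return colorsys.rgb_to_hsv(r, g, b)


def sincos_hsv(r: float, g: float, b: float) -> tuple[float, float, float, float]:
    """Alternative encoding (not used in this run): hue as (sin, cos) on the unit circle."""
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    angle = 2.0 * math.pi * h
    return math.sin(angle), math.cos(angle), s, v


def hue_gap(a: tuple[float, ...], b: tuple[float, ...], dims: int) -> float:
    """Euclidean distance between the first `dims` (hue) components of two encodings."""
    return math.dist(a[:dims], b[:dims])


# Two reddish pixels on either side of the hue wrap-around point.
reddish_a = (1.0, 0.0, 0.06)  # hue just below 1.0
reddish_b = (1.0, 0.06, 0.0)  # hue just above 0.0

raw_gap = hue_gap(raw_hsv(*reddish_a), raw_hsv(*reddish_b), dims=1)
sincos_gap = hue_gap(sincos_hsv(*reddish_a), sincos_hsv(*reddish_b), dims=2)
print(f"raw hue gap: {raw_gap:.3f}, sin/cos hue gap: {sincos_gap:.3f}")
```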
---

## Model Complexity

| Metric | Value |
|------|------|
| Parameters | **49.47M** |
| GFLOPs | **17.16** |
| Weights size | **~200 MB** |

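
The weights size follows from the parameter count: stored as 32-bit floats, each parameter takes 4 bytes. A quick check, assuming fp32 storage (the card does not state the dtype) and decimal megabytes; the helper name is illustrative.

```python
def checkpoint_size_mb(num_params: float, bytes_per_param: int = 4) -> float:
    """Approximate tensor payload in decimal MB (1 MB = 10**6 bytes).

    Ignores any non-parameter buffers and the file header, so it is a lower bound.
    """
    return num_params * bytes_per_param / 1e6


params = 49.47e6  # parameter count from the Model Complexity table
print(f"fp32: {checkpoint_size_mb(params):.1f} MB")     # fp32: 197.9 MB
print(f"fp16: {checkpoint_size_mb(params, 2):.1f} MB")  # fp16: 98.9 MB
```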
---

## Final Test Performance

Evaluation was performed with the **EMA weights from the best checkpoint (epoch 93)**.

| Metric | Score |
|------|------|
| **Top-1 Accuracy** | **96.01%** |
| **Macro-F1** | **95.51%** |
| **Macro-AUC** | **99.59%** |

**Benchmark details**

- Test images: **212**
- Total inference time: **2.30 s**
- Throughput: **92.3 images/sec**

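
For reference, the three reported scores can be computed from test-set predictions roughly as follows. This dependency-free sketch assumes "macro" means an unweighted mean over classes and that macro-AUC is one-vs-rest ROC AUC per class (rank-based, Mann-Whitney form) averaged over classes; the evaluation script behind this card is not shown, so treat it as illustrative. scikit-learn offers equivalents via `f1_score(..., average="macro")` and `roc_auc_score(..., multi_class="ovr", average="macro")`.

```python
from __future__ import annotations


def top1_accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of samples whose predicted class equals the label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true: list[int], y_pred: list[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1; a class with no support and no predictions scores 0."""
    scores = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


def binary_auc(labels: list[bool], scores: list[float]) -> float:
    """ROC AUC as P(random positive scores higher than random negative), ties count 1/2.

    Needs at least one positive and one negative; O(P * N) pairs, fine for ~200 images.
    """
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def macro_auc_ovr(y_true: list[int], probs: list[list[float]], num_classes: int) -> float:
    """One-vs-rest AUC for each class, averaged without class weighting."""
    per_class = [
        binary_auc([t == c for t in y_true], [row[c] for row in probs])
        for c in range(num_classes)
    ]
    return sum(per_class) / num_classes


# Toy example with 3 classes: labels, hard predictions, and softmax-style scores.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2]
probs = [
    [0.8, 0.1, 0.1],
    [0.4, 0.5, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
    [0.1, 0.2, 0.7],
]
print(f"accuracy {top1_accuracy(y_true, y_pred):.3f}")     # accuracy 0.833
print(f"macro-F1 {macro_f1(y_true, y_pred, 3):.3f}")       # macro-F1 0.822
print(f"macro-AUC {macro_auc_ovr(y_true, probs, 3):.3f}")  # macro-AUC 1.000
```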
---

## Inference Speed

- **Measurement:** forward pass only, timed after warm-up
- **Throughput:** **92.3 img/s** on GPU

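
A post-warmup, forward-only throughput figure is usually measured along these lines: run a few untimed warm-up passes, then time only the forward calls and divide the image count by elapsed seconds. The helper below is a framework-agnostic sketch; `measure_throughput`, its parameters, and the warm-up count of 3 are assumptions, not the script behind the numbers above. On CUDA, passing `torch.cuda.synchronize` as `sync` makes the clock wait for queued kernels. As a sanity check on the reported figures, 212 images / 92.3 img/s ≈ 2.30 s.

```python
from __future__ import annotations

import time
from typing import Any, Callable, Sequence


def measure_throughput(
    forward: Callable[[Any], Any],
    batches: Sequence[Any],
    warmup: int = 3,
    sync: Callable[[], None] = lambda: None,
    batch_len: Callable[[Any], int] = len,
) -> float:
    """Images per second for forward passes over `batches`, after `warmup` untimed passes."""
    # Warm-up passes (not timed) let lazy initialisation and kernel autotuning settle.
    for batch in batches[:warmup]:
        forward(batch)
    sync()

    num_images = sum(batch_len(b) for b in batches)  # counted outside the timed region
    start = time.perf_counter()
    for batch in batches:
        forward(batch)
    sync()  # on GPU, wait for queued work before stopping the clock
    elapsed = time.perf_counter() - start
    return num_images / elapsed if elapsed > 0 else float("inf")


# Toy run: 6 batches of 32 plus one of 20 = 212 "images", with a trivial forward.
fake_batches = [[0.0] * 32 for _ in range(6)] + [[0.0] * 20]
print(f"{measure_throughput(sum, fake_batches):.1f} images/sec")

# With PyTorch (assumes `model` and a list of input tensors `batches` already exist):
# with torch.inference_mode():
#     ips = measure_throughput(model, batches, sync=torch.cuda.synchronize)
```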
---

## Training Details

- **Experiment name:** `swin_small_hsv_raw`
- **Device:** CUDA
- **Epochs:** 100
- **Best checkpoint:** Epoch 93
- **Gradient accumulation:** 1
- **HSV gate warmup:** 5 epochs

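
The reported test scores use EMA weights. An exponential moving average keeps a shadow copy of the weights, updated after each optimizer step as `ema = decay * ema + (1 - decay) * weight`. The sketch below applies that rule to a plain dict of named scalars so it runs without PyTorch; the class name and the decay of 0.999 are assumptions (the card does not state the decay). In a PyTorch loop the same update is applied tensor by tensor over the model's `state_dict()`, and timm ships helpers such as `timm.utils.ModelEmaV2` for this.

```python
from __future__ import annotations


class WeightEMA:
    """Exponential moving average of named weights (pure-Python illustration)."""

    def __init__(self, weights: dict[str, float], decay: float = 0.999) -> None:
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        self.shadow = dict(weights)  # start from a copy of the current weights

    def update(self, weights: dict[str, float]) -> None:
        """Call once per optimizer step: shadow <- decay * shadow + (1 - decay) * weights."""
        d = self.decay
        for name, value in weights.items():
            self.shadow[name] = d * self.shadow[name] + (1.0 - d) * value


# Toy run: the live weight jumps from 0 to 1; the EMA approaches it gradually.
ema = WeightEMA({"w": 0.0}, decay=0.9)
for _ in range(10):
    ema.update({"w": 1.0})
print(round(ema.shadow["w"], 4))  # 0.6513, i.e. 1 - 0.9**10
```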
---

## 📦 Model Files

- `model.safetensors`: final EMA weights (recommended)
- Config and training artifacts are included in the repository

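
Because the checkpoint is a `.safetensors` file, its tensor names and shapes can be listed without loading any weights: the file starts with an 8-byte little-endian unsigned integer giving the length of a JSON header, followed by that header, which maps each tensor name to its `dtype`, `shape`, and `data_offsets` (plus an optional `__metadata__` entry). The stdlib-only sketch below reads that header, which helps when checking the classifier head's output size (the number of classes) or spotting extra fusion-layer keys before loading. The key names in this checkpoint are not documented on this card, so inspect the output rather than assuming them; the function names are illustrative.

```python
from __future__ import annotations

import json
import struct
from pathlib import Path


def read_safetensors_header(path: str | Path) -> dict[str, dict]:
    """Return {tensor_name: {"dtype", "shape", "data_offsets"}} from a .safetensors file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional string-to-string metadata, not a tensor
    return header


def summarize(path: str | Path) -> None:
    """Print each tensor's name, dtype, and shape, plus the total element count."""
    header = read_safetensors_header(path)
    total = 0
    for name, info in sorted(header.items()):
        numel = 1
        for dim in info["shape"]:
            numel *= dim
        total += numel
        print(f"{name:60s} {info['dtype']:>5s} {info['shape']}")
    print(f"total elements: {total:,}")


# Usage: summarize("model.safetensors")
```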
---

## 🧪 Intended Use

This model is designed for:

- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers

⚠️ **Not intended as a medical or agronomic diagnostic tool.**

---

## ⚠️ Limitations

- Results are reported on a single dataset (`tea-leaf-disease`, 212 test images), so dataset-specific bias may exist; performance on images from other regions, cameras, or lighting conditions is not reported here.

---

## 🧑‍💻 How to Use (PyTorch + timm)

```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# Load weights: .safetensors files are read with safetensors, not torch.load
state = load_file("model.safetensors", device="cpu")

# Number of classes, read from the classifier head. "head.fc.weight" is the name
# recent timm versions use for Swin; adjust it if this checkpoint names it differently.
NUM_CLASSES = state["head.fc.weight"].shape[0]

# Create model (plain timm Swin-S, RGB input)
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# strict=False skips keys the plain timm model does not have (for example,
# HSV-fusion layers, if present) but still raises on shape mismatches.
# Print the skipped keys so a partial load is not silent.
incompatible = model.load_state_dict(state, strict=False)
print("Missing keys:", incompatible.missing_keys)
print("Unexpected keys:", incompatible.unexpected_keys)
model.eval()

# Preprocessing (ImageNet mean/std)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # shape: (1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted class:", pred)
```