---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
- accuracy
- f1
- auc
tags:
- swin-transformer
- timm
- image-classification
- plant-disease
- tea-leaf
- rgb-hsv
- color-aware
datasets:
- tea-leaf-disease
language:
- en
---
# Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification
This repository provides a **Swin Transformer Small** model fine-tuned for **tea leaf disease classification** using a **color-aware RGB + HSV fusion** strategy.
On the test set it reaches **96.01% top-1 accuracy**, **95.51% macro-F1**, and **99.59% macro-AUC**.

---
## Model Overview
- **Architecture:** Swin Transformer Small (`swin_small_patch4_window7_224`)
- **Pretrained:** Yes (ImageNet)
- **Input:** RGB + HSV
- **HSV Fusion:** Raw HSV channels (no sin/cos encoding)
- **Gating:** Vector gate (disabled in this run)
- **DropPath:** 0.2
- **EMA:** Enabled
- **AMP:** Enabled
- **Framework:** PyTorch (timm-style training)
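
The card does not spell out how the HSV channels are combined with RGB. As a minimal sketch of what "raw HSV" could mean at the pixel level, the snippet below uses Python's standard-library `colorsys` to append unencoded H, S, V values to the RGB values, and contrasts that with a sin/cos hue encoding, which this run does not use. Raw hue is discontinuous at red (values near 0 and near 1 are almost the same color), which is the usual motivation for the sin/cos variant. The function names and the channel layout are assumptions for illustration, not the repository's code.

```python
import colorsys
import math


def rgb_plus_raw_hsv(r, g, b):
    """Return (r, g, b, h, s, v) for one pixel, all channels in [0, 1].

    Hypothetical helper: the six-value layout is an assumption.
    """
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    return (r, g, b, h, s, v)


def rgb_plus_sincos_hsv(r, g, b):
    """Alternative (not used in this run): place hue on the unit circle."""
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    angle = 2.0 * math.pi * h
    return (r, g, b, math.sin(angle), math.cos(angle), s, v)


# Pure green sits one third of the way around the hue circle.
green = rgb_plus_raw_hsv(0.0, 1.0, 0.0)
print(green)
```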
---
## Model Complexity
| Metric | Value |
|------|------|
| Parameters | **49.47M** |
| GFLOPs | **17.16** |
| Weights size | **~200 MB** |
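
The parameter count and the weights size in the table are consistent with each other if the weights are stored as 32-bit floats. That storage type is an assumption, since the card does not state it. A quick check:

```python
# Rough size check, assuming float32 storage (an assumption: the card
# does not state the dtype of the saved weights).
PARAMS = 49.47e6       # parameter count from the table above
BYTES_PER_PARAM = 4    # float32

size_mb = PARAMS * BYTES_PER_PARAM / 1e6      # decimal megabytes
size_mib = PARAMS * BYTES_PER_PARAM / 2**20   # binary mebibytes

print(f"{size_mb:.1f} MB ({size_mib:.1f} MiB)")
```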
---
## Final Test Performance
Evaluation was performed using **EMA weights from the best checkpoint (epoch 93)**.
| Metric | Score |
|------|------|
| **Top-1 Accuracy** | **96.01%** |
| **Macro-F1** | **95.51%** |
| **Macro-AUC** | **99.59%** |

**Benchmark details**
- Test images: **212**
- Total inference time: **2.30s**
- Throughput: **92.3 images/sec**
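
Macro-F1, reported in the table in this section, is the unweighted mean of the per-class F1 scores, so rare classes count as much as common ones. That is why it can sit below top-1 accuracy. The sketch below shows the calculation on made-up labels. It is not the evaluation script used for this model, and `macro_f1` is a hypothetical helper.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all classes seen in either list."""
    classes = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    scores = []
    for c in classes:
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); define it as 0 when the class never appears.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy labels, made up for illustration (not from the tea leaf test set).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
print(round(macro_f1(y_true, y_pred), 4))
```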
---
## Inference Speed
- **Forward pass only, measured after warmup**
- **92.3 img/s** on GPU
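
As a sanity check, 212 images in 2.30 s is about 92.2 images/sec, which agrees with the reported 92.3 once rounding of the elapsed time is taken into account. The sketch below shows one way a post-warmup, forward-only measurement can be structured. `measure_throughput` is a hypothetical helper, and the stand-in workload replaces the real model so the example runs anywhere. When timing a real model on a GPU, `torch.cuda.synchronize()` should be called before each clock read, because CUDA kernels run asynchronously.

```python
import time


def measure_throughput(forward, batches, warmup=3):
    """Return items/sec for `forward` over `batches`, after untimed warmup calls.

    `forward` is any callable that takes one batch; `batches` is a list of
    sized batches. Hypothetical helper, not the script used for this card.
    """
    for batch in batches[:warmup]:
        forward(batch)                      # warmup: excluded from timing
    n_items = 0
    start = time.perf_counter()
    for batch in batches:
        forward(batch)
        n_items += len(batch)
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against zero
    return n_items / elapsed


# Stand-in workload so the sketch runs without a GPU or a model.
fake_batches = [[0] * 8 for _ in range(10)]
rate = measure_throughput(lambda b: sum(x * x for x in b), fake_batches)
print(f"{rate:.1f} items/sec")
```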
---
## Training Details
- **Experiment name:** `swin_small_hsv_raw`
- **Device:** CUDA
- **Epochs:** 100
- **Best checkpoint:** Epoch 93
- **Gradient accumulation:** 1
- **HSV gate warmup:** 5 epochs
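
The evaluated weights are EMA weights, meaning an exponential moving average of the model parameters maintained during training. The update rule is sketched below on plain Python floats. The decay values are placeholders, since the card does not state the one used, and `ema_update` is a hypothetical helper rather than the training code.

```python
def ema_update(ema_weights, model_weights, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * current.

    `decay=0.999` is a placeholder; the value used in training is not stated.
    """
    return [decay * e + (1.0 - decay) * w
            for e, w in zip(ema_weights, model_weights)]


# Two training steps in which the live weights stay at [1.0, 1.0]:
# the EMA copy moves toward them gradually instead of jumping.
ema = [0.0, 1.0]
for step_weights in ([1.0, 1.0], [1.0, 1.0]):
    ema = ema_update(ema, step_weights, decay=0.9)
print(ema)
```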
---
## Model Files
- `model.safetensors`: final EMA weights (recommended)
- Config and training artifacts included in repository
---
## Intended Use
This model is designed for:
- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers

**Not intended as a medical or agronomic diagnostic tool.**

---
## Limitations
- Dataset-specific bias may exist
---
## How to Use (PyTorch + timm)
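```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# NUM_CLASSES must match the number of classes the checkpoint was trained on;
# define it before running (the value is not stated in this card).

# Create model
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load weights (.safetensors files are read with safetensors, not torch.load)
state = load_file("model.safetensors", device="cpu")
# strict=False skips keys that do not match the plain timm Swin architecture,
# so print what was skipped instead of ignoring it silently.
result = model.load_state_dict(state, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()

# Preprocessing (ImageNet mean/std normalization)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted class:", pred)
```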