---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
- accuracy
- f1
- auc
tags:
- swin-transformer
- timm
- image-classification
- plant-disease
- tea-leaf
- rgb-hsv
- color-aware
datasets:
- tea-leaf-disease
language:
- en
---

# Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification 🌱🍃

This repository provides a **Swin Transformer Small** model fine-tuned for **tea leaf disease classification** using a **color-aware RGB + HSV fusion** strategy.  
On the test set (212 images), it reaches **96.01% top-1 accuracy** and **99.59% macro-AUC**.

---

## 🧠 Model Overview

- **Architecture:** Swin Transformer Small (`swin_small_patch4_window7_224`)
- **Pretrained:** Yes (ImageNet)
- **Input:** RGB + HSV
- **HSV Fusion:** Raw HSV channels (no sin/cos encoding)
- **Gating:** Vector gate (disabled in this run)
- **DropPath:** 0.2
- **EMA:** Enabled
- **AMP:** Enabled
- **Framework:** PyTorch (timm-style training)

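The card says the model receives raw HSV channels (no sin/cos encoding) alongside RGB, but it does not show the conversion. Below is a minimal sketch, assuming images arrive as RGB tensors scaled to [0, 1]: `rgb_to_hsv` produces hue, saturation, and value, each in [0, 1], and `hue_sincos` shows the sin/cos hue encoding that this run did *not* use. Both function names are hypothetical, and how the fusion layers combine these channels with RGB is not described here. Raw hue wraps around (0 and 1 are both red), which is the usual motivation for a sin/cos encoding.

```python
import math

import torch


def rgb_to_hsv(rgb: torch.Tensor) -> torch.Tensor:
    """Convert a (B, 3, H, W) RGB batch in [0, 1] to raw HSV, every channel in [0, 1]."""
    r, g, b = rgb.unbind(dim=1)
    maxc = rgb.amax(dim=1)
    minc = rgb.amin(dim=1)
    delta = maxc - minc

    # Value is the largest channel; saturation is delta / max (0 for black, where delta = 0).
    v = maxc
    s = delta / torch.where(maxc > 0, maxc, torch.ones_like(maxc))

    # Hue in sixths of the colour wheel, chosen by which channel is largest
    # (red wins ties, then green, mirroring Python's colorsys module).
    safe_delta = torch.where(delta > 0, delta, torch.ones_like(delta))
    h_r = torch.remainder((g - b) / safe_delta, 6.0)
    h_g = (b - r) / safe_delta + 2.0
    h_b = (r - g) / safe_delta + 4.0
    is_r = r == maxc
    is_g = (g == maxc) & ~is_r
    h = torch.where(is_r, h_r, torch.where(is_g, h_g, h_b)) / 6.0
    h = torch.where(delta > 0, h, torch.zeros_like(h))  # grey pixels get hue 0
    return torch.stack([h, s, v], dim=1)


def hue_sincos(hsv: torch.Tensor) -> torch.Tensor:
    """Alternative NOT used in this run: put hue on the unit circle so 0 and 1 coincide."""
    angle = 2.0 * math.pi * hsv[:, 0:1]
    return torch.cat([torch.sin(angle), torch.cos(angle), hsv[:, 1:]], dim=1)
```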
---

## πŸ“ Model Complexity

| Metric | Value |
|------|------|
| Parameters | **49.47M** |
| GFLOPs | **17.16** |
| Weights size | **~200 MB** |

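The parameter count and weight size in the table are related by simple arithmetic: at 4 bytes per fp32 parameter, 49.47M parameters come to roughly 198 MB (decimal megabytes), consistent with the ~200 MB figure. That the weights are stored in fp32 is an assumption. The hypothetical helpers below reproduce that arithmetic and count parameters on any `torch.nn.Module`; note that `count_params` on a stock timm Swin-S will not include parameters of any fusion layers, if the full model has them.

```python
import torch.nn as nn


def count_params(model: nn.Module, trainable_only: bool = False) -> int:
    """Total number of parameter elements in a module."""
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad or not trainable_only
    )


def weights_size_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Approximate checkpoint size in decimal megabytes (fp32 = 4 bytes per parameter)."""
    return num_params * bytes_per_param / 1e6


print(f"{weights_size_mb(49_470_000):.1f} MB")  # 49.47M fp32 parameters -> 197.9 MB
```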
---

## 📊 Final Test Performance

Evaluation was performed with the **EMA weights from the best checkpoint (epoch 93)**.

| Metric | Score |
|------|------|
| **Top-1 Accuracy** | **96.01%** |
| **Macro-F1** | **95.51%** |
| **Macro-AUC** | **99.59%** |

**Benchmark details**
- Test images: **212**
- Total inference time: **2.30s**
- Throughput: **92.3 images/sec**

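The card does not state how these metrics were computed. Below is a minimal scikit-learn sketch, assuming macro-F1 is the unweighted mean of per-class F1 scores and macro-AUC is one-vs-rest ROC AUC on softmax probabilities (both are assumptions; `classification_metrics` is a hypothetical helper).

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def classification_metrics(y_true: np.ndarray, logits: np.ndarray) -> dict:
    """Top-1 accuracy, macro-F1 and one-vs-rest macro-AUC from raw logits."""
    probs = softmax(logits)
    y_pred = probs.argmax(axis=1)
    labels = np.arange(logits.shape[1])
    return {
        "top1_acc": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, labels=labels, average="macro"),
        # One-vs-rest AUC per class, then an unweighted mean over classes
        "macro_auc": roc_auc_score(
            y_true, probs, multi_class="ovr", average="macro", labels=labels
        ),
    }
```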
---

## 🚀 Inference Speed

- **Measurement:** forward pass only, after warmup
- **Throughput:** 92.3 img/s on GPU

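The benchmark figures agree with each other: 212 images in 2.30 s is about 92.2 img/s, matching the reported 92.3 img/s up to rounding of the timing. "Post-warmup forward-only" usually means a few untimed warmup passes followed by timed forward passes under `torch.no_grad()`, with data loading and preprocessing excluded. The function below is a sketch of that procedure; its name, the batch size you pass in, and the iteration counts are assumptions, not the settings behind the reported number. For example, `measure_throughput(model.cuda(), torch.randn(32, 3, 224, 224))`.

```python
import time

import torch
import torch.nn as nn


@torch.no_grad()
def measure_throughput(
    model: nn.Module,
    batch: torch.Tensor,
    warmup_iters: int = 5,
    timed_iters: int = 20,
) -> float:
    """Images per second for forward passes only, measured after warmup."""
    model.eval()
    device = next(model.parameters()).device
    batch = batch.to(device)  # host-to-device copy happens before timing starts

    for _ in range(warmup_iters):  # untimed passes let kernels and caches settle
        model(batch)
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # CUDA launches are async; drain the queue first

    start = time.perf_counter()
    for _ in range(timed_iters):
        model(batch)
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # wait for the last kernels before stopping the clock
    elapsed = time.perf_counter() - start

    return timed_iters * batch.shape[0] / elapsed
```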
---

## πŸ—‚οΈ Training Details

- **Experiment name:** `swin_small_hsv_raw`
- **Device:** CUDA
- **Epochs:** 100
- **Best checkpoint:** Epoch 93
- **Gradient accumulation:** 1
- **HSV gate warmup:** 5 epochs (configured; gating was disabled in this run)

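The recipe lists EMA, AMP, DropPath 0.2, and no gradient accumulation, but not the optimizer, loss, or EMA decay. Below is a plain-PyTorch sketch of how those pieces typically fit into one training step. The `ModelEMA` class, the decay of 0.9998, and the cross-entropy loss are assumptions for illustration, not this run's configuration; timm also ships its own EMA helpers (for example `timm.utils.ModelEmaV2`). The DropPath rate corresponds to timm's `drop_path_rate` argument, e.g. `timm.create_model("swin_small_patch4_window7_224", pretrained=True, num_classes=NUM_CLASSES, drop_path_rate=0.2)`, and on GPU the scaler would be `torch.cuda.amp.GradScaler()` (`torch.amp.GradScaler("cuda")` in newer releases).

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelEMA:
    """Exponential moving average of a model's weights (hypothetical helper)."""

    def __init__(self, model: nn.Module, decay: float = 0.9998):  # decay is an assumption
        self.decay = decay
        self.ema = copy.deepcopy(model).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        model_state = model.state_dict()
        # state_dict() tensors share storage with the EMA model, so in-place ops update it
        for name, ema_tensor in self.ema.state_dict().items():
            src = model_state[name].detach()
            if ema_tensor.dtype.is_floating_point:
                # ema = decay * ema + (1 - decay) * current
                ema_tensor.mul_(self.decay).add_(src, alpha=1.0 - self.decay)
            else:
                ema_tensor.copy_(src)  # e.g. integer BatchNorm counters


def train_step(model, ema, images, targets, optimizer, scaler, device_type="cuda"):
    """One optimisation step with mixed precision and no gradient accumulation."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device_type, enabled=scaler.is_enabled()):
        loss = F.cross_entropy(model(images), targets)
    scaler.scale(loss).backward()  # loss scaling guards fp16 gradients against underflow
    scaler.step(optimizer)
    scaler.update()
    ema.update(model)  # EMA tracks the weights after every optimizer step
    return loss.item()
```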
---

## 📦 Model Files

- `model.safetensors`: final EMA weights (recommended)
- Config and training artifacts are included in the repository

---

## 🧪 Intended Use

This model is designed for:
- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers

⚠️ **Not intended as a medical or agronomic diagnostic tool.**

---

## ⚠️ Limitations

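- Dataset-specific bias may exist: the model was fine-tuned and evaluated on a single tea leaf disease dataset, so performance may drop on images from other cultivars, cameras, or field conditions.
- Metrics are reported on a small test set (212 images).

---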

## πŸ§‘β€πŸ’» How to Use (PyTorch + timm)

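```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# Placeholder: set to the number of classes in the training config shipped with this repo.
NUM_CLASSES = ...

# Create the stock timm Swin-S backbone with a classifier head of the right size
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load weights: .safetensors files are read with safetensors, not torch.load
state = load_file("model.safetensors", device="cpu")
result = model.load_state_dict(state, strict=False)
# strict=False tolerates keys that do not match the stock timm Swin-S. That model has
# no HSV fusion layers, so any such parameters in the checkpoint show up as unexpected
# keys and are NOT used; predictions may then differ from the reported metrics.
# Missing backbone or head keys would mean the weights did not load as intended.
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()

# Preprocessing (ImageNet mean/std); this pipeline feeds RGB only
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)
probs = logits.softmax(dim=1)
pred = probs.argmax(dim=1).item()

print("Predicted class:", pred, f"(p={probs[0, pred].item():.3f})")
```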