---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
- accuracy
- f1
- auc
tags:
- swin-transformer
- timm
- image-classification
- plant-disease
- tea-leaf
- rgb-hsv
- color-aware
datasets:
- tea-leaf-disease
language:
- en
---

# Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification πŸŒ±πŸƒ

This repository provides a **Swin Transformer Small** model fine-tuned for **tea leaf disease classification** using a **color-aware RGB + HSV fusion** strategy.
The model generalizes well, reaching **96.01% top-1 accuracy** and **99.59% macro-AUC** on the held-out test set.

---

## 🧠 Model Overview

- **Architecture:** Swin Transformer Small (`swin_small_patch4_window7_224`)
- **Pretrained:** Yes (ImageNet)
- **Input:** RGB + HSV
- **HSV fusion:** Raw HSV channels (no sin/cos encoding)
- **Gating:** Vector gate (disabled in this run)
- **DropPath:** 0.2
- **EMA:** Enabled
- **AMP:** Enabled
- **Framework:** PyTorch (timm-style training)

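The fusion module itself is not reproduced in this card. As a minimal sketch, preparing a raw RGB + HSV input might look like the following, where `rgb_hsv_tensor` is a hypothetical helper and simple channel concatenation is an assumption about how the two color spaces are combined:

```python
import numpy as np
import torch
from PIL import Image

def rgb_hsv_tensor(img: Image.Image) -> torch.Tensor:
    """Stack raw RGB and HSV channels into one 6-channel tensor.

    NOTE: channel concatenation is an assumption; the model's actual
    fusion/gating module may combine the color spaces differently.
    """
    rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    # Raw HSV channels, no sin/cos hue encoding (matches this run's config)
    hsv = np.asarray(img.convert("HSV"), dtype=np.float32) / 255.0
    fused = np.concatenate([rgb, hsv], axis=-1)       # H x W x 6
    return torch.from_numpy(fused).permute(2, 0, 1)   # 6 x H x W
```
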
---

## πŸ“ Model Complexity

| Metric | Value |
|--------|-------|
| Parameters | **49.47M** |
| GFLOPs | **17.16** |
| Weights size | **~200 MB** |

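The reported weight size is consistent with storing 49.47M parameters in fp32:

```python
# Sanity check: 49.47M parameters at 4 bytes each (fp32).
n_params = 49.47e6
size_mb = n_params * 4 / 1e6   # decimal megabytes
print(f"{size_mb:.0f} MB")     # ~198 MB, matching the "~200 MB" figure
```
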
54
+
55
+ ## πŸ“Š Final Test Performance
56
+
57
+ Evaluation performed using **EMA weights from the best checkpoint (epoch 93)**.
58
+
59
+ | Metric | Score |
60
+ |------|------|
61
+ | **Top-1 Accuracy** | **96.01%** |
62
+ | **Macro-F1** | **95.51%** |
63
+ | **Macro-AUC** | **99.59%** |
64
+
65
+ **Benchmark details**
66
+ - Test images: **212**
67
+ - Total inference time: **2.30s**
68
+ - Throughput: **92.3 images/sec**
69
+
---

## πŸš€ Inference Speed

- Measured **post-warmup, forward pass only**
- **92.3 img/s** on GPU

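A post-warmup, forward-only throughput measurement along these lines can reproduce such numbers (the `images_per_second` helper below is illustrative, not the script used for this card):

```python
import time
import torch

def images_per_second(model, x, warmup=10, iters=50):
    """Forward-only throughput after warmup.

    On GPU, call torch.cuda.synchronize() before reading the timer
    so queued kernels are included in the measurement.
    """
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):   # warmup passes are excluded from timing
            model(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        dt = time.perf_counter() - t0
    return iters * x.shape[0] / dt
```
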
---

## πŸ—‚οΈ Training Details

- **Experiment name:** `swin_small_hsv_raw`
- **Device:** CUDA
- **Epochs:** 100
- **Best checkpoint:** epoch 93
- **Gradient accumulation:** 1
- **HSV gate warmup:** 5 epochs

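The card does not specify the shape of the 5-epoch HSV gate warmup; one plausible reading is a linear ramp of the gate strength, sketched here with a hypothetical `hsv_gate_scale` helper:

```python
def hsv_gate_scale(epoch: int, warmup_epochs: int = 5) -> float:
    """Hypothetical linear warmup: ramp the HSV gate from 0 to full
    strength over the first `warmup_epochs` epochs, then hold at 1.0."""
    return min(1.0, (epoch + 1) / warmup_epochs)
```
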
---

## πŸ“¦ Model Files

- `model.safetensors` β€” final EMA weights (recommended)
- Config and training artifacts are included in the repository

---

## πŸ§ͺ Intended Use

This model is designed for:
- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers

⚠️ **Not intended as a medical or agronomic diagnostic tool.**

---

## ⚠️ Limitations

- Performance may degrade under:
  - extreme lighting changes
  - motion blur
  - unseen disease categories
- Dataset-specific bias may exist

---

## πŸ§‘β€πŸ’» How to Use (PyTorch + timm)

The snippet below loads the backbone with plain RGB preprocessing; the color-aware RGB + HSV fusion module used in training is not reproduced here, so treat it as a starting point.

```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

NUM_CLASSES = 8  # hypothetical example value; set to your label set's class count

# Create the backbone
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load weights (.safetensors files are read with safetensors, not torch.load)
state = load_file("model.safetensors")
model.load_state_dict(state, strict=False)
model.eval()

# Preprocessing with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)

with torch.no_grad():
    logits = model(x)
pred = logits.argmax(dim=1).item()

print("Predicted class:", pred)
```

---