---
license: mit
pipeline_tag: image-classification
library_name: pytorch
base_model: microsoft/swin-small-patch4-window7-224
metrics:
- accuracy
- f1
- auc
tags:
- swin-transformer
- timm
- image-classification
- plant-disease
- tea-leaf
- rgb-hsv
- color-aware
datasets:
- tea-leaf-disease
language:
- en
---

# Swin Transformer (RGB + HSV) for Tea Leaf Disease Classification 🌱🍃

This repository provides a **Swin Transformer Small** model fine-tuned for **tea leaf disease classification** using a **color-aware RGB + HSV fusion** strategy.
On the 212-image test set, the model reaches **96.01% top-1 accuracy** and **99.59% macro-AUC**.

---

## 🧠 Model Overview

- **Architecture:** Swin Transformer Small (`swin_small_patch4_window7_224`)
- **Pretrained:** Yes (ImageNet)
- **Input:** RGB + HSV
- **HSV Fusion:** Raw HSV channels (no sin/cos encoding)
- **Gating:** Vector gate (disabled in this run)
- **DropPath:** 0.2
- **EMA:** Enabled
- **AMP:** Enabled
- **Framework:** PyTorch (timm-style training)

---

## 📐 Model Complexity

| Metric | Value |
|------|------|
| Parameters | **49.47M** |
| GFLOPs | **17.16** |
| Weights size | **~200 MB** |

---

## 📊 Final Test Performance

Evaluation was performed using **EMA weights from the best checkpoint (epoch 93)**.

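
As a reference for how the macro-averaged metrics in this section are defined, here is a minimal pure-Python sketch. It is not the evaluation script used for this model: the function names, the toy labels, and the toy probabilities are illustrative assumptions, and the AUC uses a one-vs-rest formulation, which this card does not explicitly confirm.

```python
def top1_accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class equals the true class."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of the per-class F1 scores."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


def binary_auc(labels, scores):
    """Probability that a positive sample outranks a negative one (ties count 0.5).

    Requires at least one positive and one negative sample.
    """
    pos = [s for is_pos, s in zip(labels, scores) if is_pos]
    neg = [s for is_pos, s in zip(labels, scores) if not is_pos]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def macro_auc(y_true, probs, num_classes):
    """One-vs-rest AUC averaged over classes; probs[i][c] is the score of class c."""
    aucs = []
    for c in range(num_classes):
        labels = [t == c for t in y_true]
        scores = [row[c] for row in probs]
        aucs.append(binary_auc(labels, scores))
    return sum(aucs) / num_classes


# Toy example with 3 hypothetical classes (not data from this model).
y_true = [0, 0, 1, 1, 2, 2]
probs = [
    [0.8, 0.1, 0.1],
    [0.4, 0.5, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.6, 0.3],
    [0.1, 0.2, 0.7],
    [0.3, 0.3, 0.4],
]
# Predicted class = index of the highest score in each row.
y_pred = [max(range(len(row)), key=row.__getitem__) for row in probs]

print("Top-1 accuracy:", top1_accuracy(y_true, y_pred))
print("Macro-F1:", macro_f1(y_true, y_pred, 3))
print("Macro-AUC:", macro_auc(y_true, probs, 3))
```

In practice a library such as scikit-learn (`f1_score(..., average="macro")` and `roc_auc_score(..., multi_class="ovr")`) would typically be used instead of hand-written loops.
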
| Metric | Score |
|------|------|
| **Top-1 Accuracy** | **96.01%** |
| **Macro-F1** | **95.51%** |
| **Macro-AUC** | **99.59%** |

**Benchmark details**

- Test images: **212**
- Total inference time: **2.30s**
- Throughput: **92.3 images/sec**

---

## 🚀 Inference Speed

- **Post-warmup forward-only**
- **92.3 img/s** on GPU

---

## 🗂️ Training Details

- **Experiment name:** `swin_small_hsv_raw`
- **Device:** CUDA
- **Epochs:** 100
- **Best checkpoint:** Epoch 93
- **Gradient accumulation:** 1
- **HSV gate warmup:** 5 epochs

---

## 📦 Model Files

- `model.safetensors`: final EMA weights (recommended)
- Config and training artifacts are included in the repository

---

## 🧪 Intended Use

This model is designed for:

- Tea leaf disease classification
- Agricultural decision-support systems
- Research on color-aware vision transformers

⚠️ **Not intended as a medical or agronomic diagnostic tool.**

---

## ⚠️ Limitations

- Dataset-specific bias may exist: the reported metrics come from a single 212-image test set.

---

## 🧑‍💻 How to Use (PyTorch + timm)

```python
import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# The class count is not stated in this card. The value below is a
# placeholder: replace it with the number of classes in your label set.
NUM_CLASSES = 5

# Create model
model = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load weights. A .safetensors file must be read with safetensors,
# not torch.load (which expects a pickle-based checkpoint).
state = load_file("model.safetensors", device="cpu")

# strict=False skips keys that do not match by name, so print the report.
# Any HSV-fusion weights in the checkpoint would show up as unexpected keys,
# because this plain timm model only has the RGB backbone.
result = model.load_state_dict(state, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()

# Preprocessing (ImageNet mean/std)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

img = Image.open("tea_leaf.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)

pred = logits.argmax(dim=1).item()
print("Predicted class:", pred)
```