Rice Leaf Disease Classifier - Swin Transformer
Model Description
This model performs automated classification of rice leaf diseases using a Swin Transformer Tiny architecture pretrained on ImageNet-1K and fine-tuned on a consolidated rice disease dataset. It is designed to assist agricultural practitioners, researchers, and farmers in early detection of pathological conditions affecting rice crops.
Key Capabilities
- Classifies rice leaf images into 6 disease categories
- Handles field-captured images with varying lighting, angles, and backgrounds
- Outputs confidence scores for each prediction
- Optimized for inference on CPU/GPU with mixed precision support
Model Details
| Property | Value |
|---|---|
| Model Architecture | swin_tiny_patch4_window7_224 (Swin Transformer) |
| Base Pretraining | ImageNet-1K (via timm) |
| Input Resolution | 224×224 pixels |
| Input Channels | 3 (RGB) |
| Number of Classes | 6 |
| Output Format | Logits → Softmax probabilities |
| Framework | PyTorch 2.0+ |
| Precision | FP16 (AMP) supported |
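To illustrate the output format listed in the table, here is a minimal standard-library sketch of how six raw logits become softmax probabilities. The logit values are made up for illustration and are not real model output.

```python
import math

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtract the largest logit first for numerical stability.
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the six classes (illustrative values only).
example_logits = [2.0, 0.5, -1.0, 0.0, 4.0, 1.0]
example_probs = softmax(example_logits)

# Index of the most probable class; the label list maps it to a name.
top_index = max(range(len(example_probs)), key=example_probs.__getitem__)
print(top_index, round(example_probs[top_index], 3))
```

The predicted class is the index with the highest probability, and that probability is what the model reports as its confidence.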
Class Labels
labels = [
    "Bacterial Leaf Blight",
    "Brown Spot",
    "Leaf Blast",
    "Sheath Blight",
    "Tungro",
    "Leaf Scald",
]
Training & Evaluation
Dataset Composition
The model was trained on a merged dataset drawn from three Kaggle sources; the dataset contributors are credited under Acknowledgments.
Preprocessing:
- Images resized/cropped to 224×224
- Normalized with ImageNet statistics: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- Augmentations: RandomResizedCrop, Horizontal/Vertical Flip, ColorJitter
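As a rough sketch of what the normalization step does, the snippet below applies the ImageNet mean and standard deviation listed above to a single RGB pixel that has already been scaled to the [0, 1] range. The helper name `normalize_pixel` is hypothetical; in practice a library transform applies the same arithmetic to the whole image tensor.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Normalize one RGB pixel (values in [0, 1]) channel by channel."""
    return tuple(
        (value - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )

# A pixel equal to the ImageNet mean maps to zero in every channel.
mean_pixel = normalize_pixel(IMAGENET_MEAN)

# A pure white pixel maps to positive values between about 2.2 and 2.7.
white_pixel = normalize_pixel((1.0, 1.0, 1.0))
print(mean_pixel, white_pixel)
```

Using the same statistics at inference time as during training matters because the pretrained backbone expects inputs on this scale.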
Training Configuration
optimizer: AdamW (lr=1e-4, weight_decay=0.05)
scheduler: OneCycleLR (5% warmup, cosine annealing)
batch_size: 16
epochs: 15
loss: CrossEntropyLoss (label_smoothing=0.1)
mixed_precision: AMP (cuda)
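The learning-rate schedule can be sketched in plain Python to show its shape. This is an approximation under stated assumptions, not the author's training code: it treats lr=1e-4 as the peak rate, uses the PyTorch `OneCycleLR` defaults `div_factor=25` and `final_div_factor=1e4` (the configuration above does not state them), and picks a hypothetical `total_steps` of 1000.

```python
import math

def cosine_interp(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

def one_cycle_lr(step, total_steps, max_lr=1e-4, pct_start=0.05,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at a given step for a two-phase one-cycle schedule.

    Assumes total_steps is large enough that pct_start * total_steps > 1.
    """
    initial_lr = max_lr / div_factor          # rate at step 0
    min_lr = initial_lr / final_div_factor    # rate at the final step
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    if step <= warmup_end:
        # Phase 1: warm up from initial_lr to max_lr.
        return cosine_interp(initial_lr, max_lr, step / warmup_end)
    # Phase 2: anneal from max_lr down to min_lr.
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cosine_interp(max_lr, min_lr, pct)

total_steps = 1000  # hypothetical; the real value is epochs * batches per epoch
schedule = [one_cycle_lr(s, total_steps) for s in range(total_steps)]
peak_step = max(range(total_steps), key=schedule.__getitem__)
print(peak_step, schedule[0], schedule[peak_step], schedule[-1])
```

With a 5% warmup, the rate peaks very early and spends the remaining 95% of training decaying, which suits fine-tuning a pretrained backbone.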
Performance Metrics
| Metric | Value |
|---|---|
| Final Validation Accuracy | 97.11% |
| Training Loss (final epoch) | 0.500 |
| Convergence Epoch | ~12 |
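When reading the final training loss, it may help to remember that label smoothing keeps cross-entropy from reaching zero. The sketch below computes the lowest achievable loss for 6 classes with label_smoothing=0.1. It assumes the convention used by PyTorch's `CrossEntropyLoss`, where the smoothing mass is spread uniformly over all classes, including the true one. Under that assumption the floor is roughly 0.42, so a training loss of 0.500 sits fairly close to it.

```python
import math

def smoothed_targets(num_classes, true_index, smoothing):
    """Target distribution: uniform smoothing mass plus the rest on the true class."""
    base = smoothing / num_classes
    targets = [base] * num_classes
    targets[true_index] += 1.0 - smoothing
    return targets

def cross_entropy(targets, probs):
    """Cross-entropy between a target distribution and predicted probabilities."""
    return -sum(t * math.log(p) for t, p in zip(targets, probs) if t > 0)

targets = smoothed_targets(num_classes=6, true_index=0, smoothing=0.1)

# The loss is smallest when the predictions match the smoothed targets exactly,
# so the floor equals the entropy of the target distribution.
loss_floor = cross_entropy(targets, targets)
print(f"minimum achievable loss: {loss_floor:.3f}")
```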
How to Use
Install Dependencies
pip install torch torchvision timm albumentations pillow
Load & Run Inference (PyTorch)
import torch
import timm
from PIL import Image
from torchvision import transforms

# Load model: same architecture as in training, with a 6-class head
model = timm.create_model(
    "swin_tiny_patch4_window7_224",
    pretrained=False,
    num_classes=6
)
model.load_state_dict(
    torch.load("rice_model.pth", map_location="cpu", weights_only=True)
)
model.eval()

# Preprocessing: resize to the model input size and apply ImageNet normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict
def predict(image_path: str) -> dict:
    img = Image.open(image_path).convert("RGB")
    input_tensor = transform(img).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.softmax(logits, dim=1)[0]
    labels = ["Bacterial Leaf Blight", "Brown Spot", "Leaf Blast",
              "Sheath Blight", "Tungro", "Leaf Scald"]
    return {
        "prediction": labels[probs.argmax().item()],
        "confidence": probs.max().item(),
        "all_scores": dict(zip(labels, probs.tolist()))
    }

# Example
result = predict("sample_leaf.jpg")
print(f"{result['prediction']}: {result['confidence']:.2%}")
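For downstream use, the `all_scores` dictionary returned by `predict` can be ranked and filtered. The following sketch shows one possible approach. The helper `top_k`, the 0.6 threshold, and the sample scores are all hypothetical and would need tuning and validation on real field data.

```python
def top_k(all_scores, k=3, min_confidence=0.6):
    """Rank class scores and flag predictions whose best score is low."""
    ranked = sorted(all_scores.items(), key=lambda item: item[1], reverse=True)
    best_score = ranked[0][1]
    return {
        "top": ranked[:k],
        # Low-confidence results may warrant a second photo or expert review.
        "needs_review": best_score < min_confidence,
    }

# Hypothetical scores in the same shape as predict(...)["all_scores"].
sample_scores = {
    "Bacterial Leaf Blight": 0.05,
    "Brown Spot": 0.48,
    "Leaf Blast": 0.40,
    "Sheath Blight": 0.03,
    "Tungro": 0.02,
    "Leaf Scald": 0.02,
}
summary = top_k(sample_scores, k=2)
print(summary)
```

Showing the runner-up class alongside the top prediction can help when two diseases produce visually similar lesions.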
License
This model is released under the MIT License.
Datasets are subject to their original Kaggle licenses—please review source terms before commercial use.
Acknowledgments
- Ross Wightman for the timm library and Swin Transformer implementations
- Kaggle dataset contributors: @anshulm257, @nirmalsankalana, @vbookshelf
- Hugging Face for the model hosting infrastructure
- Google Gemini for assistance in code generation and analysis
Model developed for agricultural AI research. Not intended for standalone commercial deployment without validation.
