Rice Leaf Disease Classifier - Swin Transformer


Model Description

This model performs automated classification of rice leaf diseases using a Swin Transformer Tiny architecture pretrained on ImageNet-1K and fine-tuned on a consolidated rice disease dataset. It is designed to assist agricultural practitioners, researchers, and farmers in early detection of pathological conditions affecting rice crops.

Key Capabilities

  • Classifies rice leaf images into 6 disease categories
  • Handles field-captured images with varying lighting, angles, and backgrounds
  • Outputs confidence scores for each prediction
  • Runs inference on CPU or GPU, with FP16 mixed precision (AMP) supported on CUDA devices

Model Details

Property | Value
Model Architecture | swin_tiny_patch4_window7_224 (Swin Transformer)
Base Pretraining | ImageNet-1K (via timm)
Input Resolution | 224×224 pixels
Input Channels | 3 (RGB)
Number of Classes | 6
Output Format | Logits → Softmax probabilities
Framework | PyTorch 2.0+
Precision | FP16 (AMP) supported
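
The "Logits → Softmax probabilities" output format can be illustrated without any framework. The sketch below is a minimal pure-Python softmax over six logits. The logit values are made up for illustration and are not outputs of this model.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Convert raw logits to probabilities that sum to 1."""
    # Subtract the largest logit first for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for the 6 classes (illustrative values only).
example_logits = [2.0, 0.5, -1.0, 0.0, 4.0, 1.0]
example_probs = softmax(example_logits)

# The predicted class is the index of the largest probability.
predicted_index = example_probs.index(max(example_probs))
```

The index of the largest probability is then looked up in the class label list to obtain the disease name.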

Class Labels

labels = [
    "Bacterial Leaf Blight",
    "Brown Spot", 
    "Leaf Blast",
    "Sheath Blight",
    "Tungro",
    "Leaf Scald"
]

Training & Evaluation

Dataset Composition

The model was trained on a merged dataset from three Kaggle sources:

  1. Rice Disease Dataset
  2. Rice Leaf Disease Image
  3. Rice Leaf Diseases

Preprocessing:

  • Images resized/cropped to 224×224
  • Normalized with ImageNet statistics: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  • Augmentations: RandomResizedCrop, Horizontal/Vertical Flip, ColorJitter
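
The normalization step above is simple per-channel arithmetic. As a minimal sketch, assuming 8-bit RGB values are first scaled to [0, 1] (as `transforms.ToTensor()` does in the inference code), a single pixel is transformed as follows. The helper name `normalize_pixel` is hypothetical.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb: tuple[int, int, int]) -> list[float]:
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return [
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    ]


# Extremes of the input range map to roughly +/-2 after normalization.
white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
```

Inputs that skip this step fall outside the value range the pretrained backbone saw during training, which usually degrades predictions.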

Training Configuration

optimizer: AdamW (lr=1e-4, weight_decay=0.05)
scheduler: OneCycleLR (5% warmup, cosine annealing)
batch_size: 16
epochs: 15
loss: CrossEntropyLoss (label_smoothing=0.1)
mixed_precision: AMP (cuda)
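
To show the shape of the learning-rate schedule listed above, here is a pure-Python sketch of a one-cycle curve with a 5% warmup and cosine annealing. It mirrors the two-phase cosine behavior of PyTorch's `OneCycleLR` but is an approximation, not the training code. Several values are assumptions: `max_lr` is taken to equal the listed `lr=1e-4`, `div_factor=25` and `final_div_factor=1e4` are PyTorch's defaults, and `STEPS_PER_EPOCH = 100` is a made-up value because the dataset size is not stated.

```python
import math


def cosine_interp(start: float, end: float, pct: float) -> float:
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


def one_cycle_lr(step: int, total_steps: int, max_lr: float = 1e-4,
                 pct_start: float = 0.05, div_factor: float = 25.0,
                 final_div_factor: float = 1e4) -> float:
    """Approximate learning rate at a given optimizer step."""
    initial_lr = max_lr / div_factor          # starting learning rate
    min_lr = initial_lr / final_div_factor    # learning rate at the last step
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    last_step = total_steps - 1
    if step <= warmup_end:
        # Phase 1: ramp up from initial_lr to max_lr.
        return cosine_interp(initial_lr, max_lr, step / warmup_end)
    # Phase 2: anneal from max_lr down to min_lr.
    pct = (step - warmup_end) / (last_step - warmup_end)
    return cosine_interp(max_lr, min_lr, pct)


EPOCHS = 15
STEPS_PER_EPOCH = 100  # hypothetical; depends on dataset size and batch size
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH

schedule = [one_cycle_lr(s, TOTAL_STEPS) for s in range(TOTAL_STEPS)]
```

Under these assumptions the rate climbs from 4e-6 to 1e-4 during the first 5% of steps and then decays smoothly toward zero.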

Performance Metrics

Metric | Value
Final Validation Accuracy | 97.11%
Training Loss (final epoch) | 0.500
Convergence Epoch | ~12
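
A final training loss of 0.500 may look high next to 97% accuracy, but with label smoothing the cross-entropy cannot reach zero. As a worked check, assuming the PyTorch `CrossEntropyLoss(label_smoothing=0.1)` convention (the true class receives 1 - ε + ε/K and every other class receives ε/K), the lowest achievable loss for K = 6 classes is the entropy of the smoothed target. The function names are hypothetical.

```python
import math


def smoothed_target(num_classes: int, epsilon: float) -> list[float]:
    """Smoothed label distribution, with the true class in position 0."""
    off_value = epsilon / num_classes
    return [1.0 - epsilon + off_value] + [off_value] * (num_classes - 1)


def loss_floor(num_classes: int, epsilon: float) -> float:
    """Minimum cross-entropy, reached when predictions equal the target."""
    target = smoothed_target(num_classes, epsilon)
    return -sum(p * math.log(p) for p in target)


min_loss = loss_floor(num_classes=6, epsilon=0.1)
print(f"Loss floor: {min_loss:.3f}")  # prints "Loss floor: 0.421"
```

Under these assumptions the reported 0.500 sits fairly close to the floor of about 0.42, so it does not by itself indicate underfitting.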

How to Use

Install Dependencies

pip install torch torchvision timm albumentations pillow

Load & Run Inference (PyTorch)

import torch
import timm
from PIL import Image
from torchvision import transforms

# Load model
model = timm.create_model(
    "swin_tiny_patch4_window7_224",
    pretrained=False,
    num_classes=6
)
model.load_state_dict(
    torch.load("rice_model.pth", map_location="cpu", weights_only=True)
)
model.eval()

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict
def predict(image_path: str) -> dict:
    img = Image.open(image_path).convert("RGB")
    input_tensor = transform(img).unsqueeze(0)
    
    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.softmax(logits, dim=1)[0]
    
    labels = ["Bacterial Leaf Blight", "Brown Spot", "Leaf Blast", 
              "Sheath Blight", "Tungro", "Leaf Scald"]
    
    return {
        "prediction": labels[probs.argmax().item()],
        "confidence": probs.max().item(),
        "all_scores": dict(zip(labels, probs.tolist()))
    }

# Example
result = predict("sample_leaf.jpg")
print(f"{result['prediction']}: {result['confidence']:.2%}")
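
Because `predict` returns a score for every class, a caller can rank alternatives or flag uncertain predictions. The helper below is a hypothetical sketch that operates on the returned dictionary. The name `summarize` and the 0.6 threshold are illustrative assumptions, not part of the released model, and `sample_result` contains made-up values rather than real model output.

```python
def summarize(result: dict, top_k: int = 3, min_confidence: float = 0.6) -> dict:
    """Rank class scores and flag predictions below a confidence threshold."""
    ranked = sorted(
        result["all_scores"].items(), key=lambda item: item[1], reverse=True
    )
    return {
        "top_k": ranked[:top_k],
        "needs_review": result["confidence"] < min_confidence,
    }


# Made-up result in the same shape that predict() returns.
sample_result = {
    "prediction": "Leaf Blast",
    "confidence": 0.55,
    "all_scores": {
        "Bacterial Leaf Blight": 0.05,
        "Brown Spot": 0.30,
        "Leaf Blast": 0.55,
        "Sheath Blight": 0.04,
        "Tungro": 0.03,
        "Leaf Scald": 0.03,
    },
}

summary = summarize(sample_result)
```

A suitable threshold depends on the deployment and should be chosen from validation data rather than taken from this example.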

License

This model is released under the MIT License.
Datasets are subject to their original Kaggle licenses; review the source terms before commercial use.

Acknowledgments

  • Ross Wightman for the timm library and Swin Transformer implementations
  • Kaggle dataset contributors: @anshulm257, @nirmalsankalana, @vbookshelf
  • Hugging Face for the model hosting infrastructure
  • Google Gemini for assistance with code generation and analysis

Model developed for agricultural AI research. Not intended for standalone commercial deployment without validation.
