Plant Disease Classifier
Species-invariant plant disease classification. Given an image of a plant leaf, the model predicts the disease, with the goal of doing so independently of the host plant species.
Model
- Backbone: ConvNeXt-Tiny (27.85M parameters)
- Input: 224x224 RGB images
- Output: 39 disease classes
- Training: Fine-tuned from ImageNet-pretrained weights with aggressive augmentation (RandAugment and random erasing)
Results
| Metric | Value |
|---|---|
| Standard validation accuracy | 0.628 |
| Holdout (leave-one-species-out) validation accuracy | 0.254 |
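The holdout split withholds specific (species, disease) pairs from training, so the model is evaluated on species-disease combinations it has never seen. This is a stronger test of whether the model learned disease features rather than species features: a model that recognizes a disease only through the species it appeared on during training is likely to fail on these combinations. The drop from 0.628 to 0.254 shows that much of the standard validation accuracy does not carry over to unseen combinations.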
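A split of this kind can be sketched as a filter over (species, disease) labels. The sample format, the helper name `pair_holdout_split` and the example pairs below are hypothetical; the card does not say which pairs were held out.

```python
from typing import Iterable, List, Set, Tuple

# Hypothetical sample format: (image_path, species, disease).
Sample = Tuple[str, str, str]


def pair_holdout_split(
    samples: Iterable[Sample], holdout_pairs: Set[Tuple[str, str]]
) -> Tuple[List[Sample], List[Sample]]:
    """Route every sample whose (species, disease) pair is held out to the eval set."""
    train: List[Sample] = []
    holdout: List[Sample] = []
    for sample in samples:
        _, species, disease = sample
        target = holdout if (species, disease) in holdout_pairs else train
        target.append(sample)
    return train, holdout


samples = [
    ("a.jpg", "tomato", "early_blight"),
    ("b.jpg", "potato", "early_blight"),
    ("c.jpg", "tomato", "healthy"),
    ("d.jpg", "potato", "healthy"),
]
train, holdout = pair_holdout_split(samples, {("potato", "early_blight")})
# "early_blight" is still seen in training (on tomato) but never on potato,
# so the holdout set measures whether disease features transfer across species.
```

For the split to test transfer rather than recognition of a brand-new class, each held-out disease should still appear in training on some other species, as in the example.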
Files
- `final_model.pth`: PyTorch checkpoint (model weights + metadata)
- `labels.json`: class index mapping and metadata
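The usage example reads several metadata fields from the checkpoint dictionary (`backbone`, `num_classes`, `model_state_dict`, `disease_to_idx`, `img_size`, `mean`, `std`). A small helper such as the hypothetical `checkpoint_summary` below can confirm these keys are present and return the metadata without the weights. The key list is taken from the usage code; any other keys the file may contain are not documented on this card.

```python
from typing import Any, Dict

# Keys the usage example reads from final_model.pth.
REQUIRED_KEYS = (
    "backbone", "num_classes", "model_state_dict",
    "disease_to_idx", "img_size", "mean", "std",
)


def checkpoint_summary(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Return the checkpoint's metadata (everything except the weights).

    Raises KeyError if a field used by the usage example is missing, and
    ValueError if the class mapping does not cover every output logit.
    """
    missing = [k for k in REQUIRED_KEYS if k not in ckpt]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    # One class name per logit, otherwise idx_to_disease lookups can fail.
    if len(ckpt["disease_to_idx"]) != ckpt["num_classes"]:
        raise ValueError("disease_to_idx size does not match num_classes")
    return {k: v for k, v in ckpt.items() if k != "model_state_dict"}
```

With a downloaded file, it would be called as `checkpoint_summary(torch.load(path, map_location="cpu", weights_only=False))`.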
Usage
Install the dependencies (`pip install torch torchvision timm pillow huggingface_hub`), then download and load the checkpoint:
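import torch, timm
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hub (cached locally after the first call).
path = hf_hub_download(repo_id="TigranBoyakhchyan/plant-disease-classifier", filename="final_model.pth")
# weights_only=False because the file stores metadata alongside the weights;
# this unpickles arbitrary objects, so only load checkpoints you trust.
ckpt = torch.load(path, map_location="cpu", weights_only=False)

# Rebuild the architecture from the stored metadata, then load the fine-tuned weights.
model = timm.create_model(ckpt["backbone"], pretrained=False, num_classes=ckpt["num_classes"])
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Invert the disease -> index mapping so predicted indices map back to names.
idx_to_disease = {v: k for k, v in ckpt["disease_to_idx"].items()}

# Evaluation preprocessing: resize the shorter side to 256, then center-crop to img_size.
tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(ckpt["img_size"]),
    transforms.ToTensor(),
    transforms.Normalize(ckpt["mean"], ckpt["std"]),
])

img = Image.open("leaf.jpg").convert("RGB")  # path to your own leaf image
x = tfm(img).unsqueeze(0)  # add a batch dimension: (1, 3, H, W)

with torch.no_grad():
    logits = model(x)
pred = idx_to_disease[logits.argmax(1).item()]
print(pred)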
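The example prints only the single most likely class. To see how confident the model is, the logits can be turned into softmax probabilities and the top few classes listed. A minimal sketch, assuming `logits` and `idx_to_disease` as defined in the usage code; `top_k_diseases` is a hypothetical helper name.

```python
import torch


def top_k_diseases(logits: torch.Tensor, idx_to_disease: dict, k: int = 3):
    """Return the k most probable (disease, probability) pairs for one image.

    `logits` has shape (1, num_classes), as produced by model(x) in the usage code.
    """
    probs = logits.softmax(dim=1)[0]  # probabilities for the single image
    k = min(k, probs.numel())         # never ask for more classes than exist
    values, indices = probs.topk(k)
    return [(idx_to_disease[i], p) for i, p in zip(indices.tolist(), values.tolist())]
```

Usage: `for disease, p in top_k_diseases(logits, idx_to_disease): print(f"{disease}: {p:.3f}")`.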