---
language:
- en
license: apache-2.0
tags:
- image-classification
- pytorch
- resnet
- beans
- agriculture
- plant-disease
datasets:
- beans
library_name: pytorch
pipeline_tag: image-classification
metrics:
- accuracy
---

# ResNet18 Fine-tuned on the Beans Dataset

A ResNet18 image classifier fine-tuned to identify the health status of bean leaves from images. The model was trained in Google Colab on a T4 GPU, and training runs were tracked with MLflow.

## Model Details

**Dataset:** [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)

**Classes:**

- Healthy
- Bean Rust
- Angular Leaf Spot

**Validation accuracy:** 0.9398 (93.98%)

## Training Configuration

**Overfitting prevention techniques:**

- Data augmentation (rotation, flipping, cropping, color jitter)
- Dropout (30%) before the final classification layer
- L2 regularization (weight decay: 1e-4)
- Learning rate scheduling (`ReduceLROnPlateau`)
- Best-model selection based on validation accuracy

**Hyperparameters:**

- Learning rate: 5e-5
- Epochs: 10
- Batch size: 32
- Weight decay: 1e-4
- Dropout: 0.3
- Optimizer: Adam

## Artifacts

- `resnet18_beans.pth` - PyTorch model weights (`state_dict`)
- `per_class_metrics.csv` - Detailed per-class metrics
- `confusion_matrix.png` - Confusion matrix visualization

## Usage

Download the weights from the Hub and load them into a ResNet18 with the same classification head used during training:

```python
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from torchvision import models

# Download the fine-tuned weights (a state_dict) from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="vGiacomov/image-classifier-beans",
    filename="resnet18_beans.pth",
)

# Rebuild the architecture: ResNet18 with a dropout + 3-class linear head,
# matching the head the weights were saved with
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, 3),
)

# Load the weights on CPU and switch to inference mode (disables dropout)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
```

## License

Apache 2.0