---
language:
- en
license: apache-2.0
tags:
- image-classification
- pytorch
- resnet
- beans
- agriculture
- plant-disease
datasets:
- beans
library_name: pytorch
pipeline_tag: image-classification
metrics:
- accuracy
---
# ResNet18 Fine-tuned on Beans Dataset
This model was trained in Google Colab using a T4 GPU and tracked with MLflow.
## Model Details
**Dataset:** [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)

**Classes:**
- Healthy
- Bean Rust
- Angular Leaf Spot

**Validation Accuracy:** 0.9398
## Training Configuration
**Overfitting Prevention Techniques:**
- Data augmentation (rotation, flip, crop, color jitter)
- Dropout (30%)
- L2 regularization (weight decay: 1e-4)
- Learning rate scheduling (ReduceLROnPlateau)
- Best model selection based on validation accuracy

**Hyperparameters:**
- Learning Rate: 5e-05
- Epochs: 10
- Batch Size: 32
- Weight Decay: 0.0001
- Dropout: 0.3
- Optimizer: Adam
## Artifacts
- `resnet18_beans.pth` - PyTorch model weights
- `per_class_metrics.csv` - Detailed per-class metrics
- `confusion_matrix.png` - Confusion matrix visualization
## Usage
Download and load the model:
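```python
from huggingface_hub import hf_hub_download
import torch
from torchvision import models
from torch import nn

# Fetch the fine-tuned weights from the Hugging Face Hub (cached locally after the first call)
model_path = hf_hub_download(
    repo_id="vGiacomov/image-classifier-beans",
    filename="resnet18_beans.pth"
)

# Rebuild the training architecture: ResNet18 with a dropout + 3-class linear head
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, 3)  # in_features is 512 for ResNet18
)

# Load the weights on CPU and switch to inference mode (disables dropout)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
```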
## License
Apache 2.0