---
language:
- en
license: apache-2.0
tags:
- image-classification
- pytorch
- resnet
- beans
- agriculture
- plant-disease
datasets:
- beans
library_name: pytorch
pipeline_tag: image-classification
metrics:
- accuracy
---
# ResNet18 Fine-tuned on Beans Dataset
A ResNet18 fine-tuned to classify bean leaf images as healthy, bean rust, or angular leaf spot. Training ran in Google Colab on a T4 GPU, and the runs were tracked with MLflow.
## Model Details
**Dataset:** [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)
**Classes:**
- Healthy
- Bean Rust
- Angular Leaf Spot
**Validation Accuracy:** 0.9398
## Training Configuration
**Overfitting Prevention Techniques:**
- Data augmentation (rotation, flip, crop, color jitter)
- Dropout (30%)
- L2 regularization (weight decay: 1e-4)
- Learning rate scheduling (ReduceLROnPlateau)
- Best model selection based on validation accuracy
**Hyperparameters:**
- Learning Rate: 5e-05
- Epochs: 10
- Batch Size: 32
- Weight Decay: 0.0001
- Dropout: 0.3
- Optimizer: Adam
## Artifacts
- `resnet18_beans.pth` - PyTorch model weights
- `per_class_metrics.csv` - Detailed per-class metrics
- `confusion_matrix.png` - Confusion matrix visualization
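Per-class metrics and a confusion matrix of this kind can be derived from predicted and true labels. The sketch below is one way to compute them in plain PyTorch, using made-up labels; it is not the script that produced the files above, and the actual column layout of `per_class_metrics.csv` is not documented here.

```python
import torch


def confusion_matrix(y_true, y_pred, num_classes=3):
    # Rows index the true class, columns index the predicted class.
    matrix = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(y_true.tolist(), y_pred.tolist()):
        matrix[t, p] += 1
    return matrix


def per_class_metrics(matrix):
    true_positives = matrix.diag().float()
    # Column sums count predictions per class; row sums count true samples.
    precision = true_positives / matrix.sum(dim=0).clamp(min=1)
    recall = true_positives / matrix.sum(dim=1).clamp(min=1)
    return precision, recall


# Illustrative labels only, not results from this model.
y_true = torch.tensor([0, 0, 1, 1, 2, 2])
y_pred = torch.tensor([0, 1, 1, 1, 2, 0])
cm = confusion_matrix(y_true, y_pred)
precision, recall = per_class_metrics(cm)
```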
## Usage
Download and load the model:
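```python
from huggingface_hub import hf_hub_download
import torch
from torchvision import models
from torch import nn

# Download the weights file from the Hub (cached locally after the first call).
model_path = hf_hub_download(
    repo_id="vGiacomov/image-classifier-beans",
    filename="resnet18_beans.pth"
)

# Rebuild the architecture: ResNet18 with a dropout + 3-class linear head.
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, 3)
)

# Load the fine-tuned weights on CPU and switch to inference mode.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
```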
## License
Apache 2.0