---
language:
- en
license: apache-2.0
tags:
- image-classification
- pytorch
- resnet
- beans
- agriculture
- plant-disease
datasets:
- beans
library_name: pytorch
pipeline_tag: image-classification
metrics:
- accuracy
---

# ResNet18 Fine-tuned on Beans Dataset

This is a ResNet18 fine-tuned to classify bean leaf images into three classes. It was trained in Google Colab on a T4 GPU, and experiments were tracked with MLflow.

## Model Details

**Dataset:** [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)

**Classes:**
- Healthy
- Bean Rust
- Angular Leaf Spot

**Validation Accuracy:** 0.9398 (93.98%)

## Training Configuration

**Overfitting Prevention Techniques:**
- Data augmentation (rotation, flip, crop, color jitter)
- Dropout (30%)
- L2 regularization (weight decay: 1e-4)
- Learning rate scheduling (ReduceLROnPlateau)
- Best model selection based on validation accuracy

**Hyperparameters:**
- Learning Rate: 5e-5
- Epochs: 10
- Batch Size: 32
- Weight Decay: 1e-4
- Dropout: 0.3
- Optimizer: Adam

## Artifacts

- `resnet18_beans.pth`: PyTorch model weights (state dict)
- `per_class_metrics.csv`: detailed per-class metrics
- `confusion_matrix.png`: confusion matrix visualization

## Usage

Download the weights from the Hub, rebuild the architecture, and load the state dict:

```python
from huggingface_hub import hf_hub_download
import torch
from torchvision import models
from torch import nn

# Download the fine-tuned weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="vGiacomov/image-classifier-beans",
    filename="resnet18_beans.pth"
)

# Rebuild the training architecture: ResNet18 whose final layer is
# replaced by a dropout + linear head for the 3 bean classes
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, 3)
)

# Load the weights on CPU and switch to inference mode (disables dropout)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
```

## License

Apache 2.0