---
language:
- en
license: apache-2.0
tags:
- image-classification
- pytorch
- resnet
- beans
- agriculture
- plant-disease
datasets:
- beans
library_name: pytorch
pipeline_tag: image-classification
metrics:
- accuracy
---

# ResNet18 Fine-tuned on Beans Dataset

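A ResNet18 image classifier fine-tuned on the Beans dataset to label bean leaf images as healthy, bean rust, or angular leaf spot. Training ran in Google Colab on a T4 GPU, and runs were tracked with MLflow.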

## Model Details

**Dataset:** [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)

**Classes:**
- Healthy
- Bean Rust
- Angular Leaf Spot

**Validation Accuracy:** 0.9398

## Training Configuration

**Overfitting Prevention Techniques:**
- Data augmentation (rotation, flip, crop, color jitter)
- Dropout (30%)
- L2 regularization (weight decay: 1e-4)
- Learning rate scheduling (ReduceLROnPlateau)
- Best model selection based on validation accuracy

**Hyperparameters:**
- Learning Rate: 5e-05
- Epochs: 10
- Batch Size: 32
- Weight Decay: 0.0001
- Dropout: 0.3
- Optimizer: Adam

## Artifacts

- `resnet18_beans.pth` - PyTorch model weights
- `per_class_metrics.csv` - Detailed per-class metrics
- `confusion_matrix.png` - Confusion matrix visualization

## Usage

Download and load the model:

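    import torch
    from huggingface_hub import hf_hub_download
    from torch import nn
    from torchvision import models
    
    # Download the weights file from the Hub (cached locally after the first call).
    model_path = hf_hub_download(
        repo_id="vGiacomov/image-classifier-beans",
        filename="resnet18_beans.pth"
    )
    
    # Rebuild the architecture: ResNet18 with a dropout + 3-class linear head.
    model = models.resnet18()
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(model.fc.in_features, 3)
    )
    
    # Load the fine-tuned weights on CPU and switch to inference mode.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()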


**License:** Apache 2.0