vGiacomov commited on
Commit
e072d30
·
1 Parent(s): 5a216d7

Update model with improved regularization and data augmentation

Browse files
README.md CHANGED
@@ -7,25 +7,76 @@ tags:
7
  - pytorch
8
  - resnet
9
  - beans
 
 
10
  datasets:
11
  - beans
12
  library_name: pytorch
13
  pipeline_tag: image-classification
 
 
14
  ---
15
 
16
- # ResNet18 fine-tuned on Beans dataset
17
 
18
- This model was trained in Google Colab using a GPU and tracked with MLflow.
19
 
20
- Dataset: [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)
21
 
22
- ## Classes:
 
 
23
  - Healthy
24
  - Bean Rust
25
  - Angular Leaf Spot
26
 
27
- Validation Accuracy: 0.9774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- ## Artifacts:
30
- - per_class_metrics.csv
31
- - confusion_matrix.png
 
7
  - pytorch
8
  - resnet
9
  - beans
10
+ - agriculture
11
+ - plant-disease
12
  datasets:
13
  - beans
14
  library_name: pytorch
15
  pipeline_tag: image-classification
16
+ metrics:
17
+ - accuracy
18
  ---
19
 
20
+ # ResNet18 Fine-tuned on Beans Dataset
21
 
22
+ This model was trained in Google Colab using a T4 GPU and tracked with MLflow.
23
 
24
+ ## Model Details
25
 
26
+ **Dataset:** [Beans](https://huggingface.co/datasets/AI-Lab-Makerere/beans)
27
+
28
+ **Classes:**
29
  - Healthy
30
  - Bean Rust
31
  - Angular Leaf Spot
32
 
33
+ **Validation Accuracy:** 0.9173
34
+
35
+ ## Training Configuration
36
+
37
+ **Overfitting Prevention Techniques:**
38
+ - Data augmentation (rotation, flip, crop, color jitter)
39
+ - Dropout (30%)
40
+ - L2 regularization (weight decay: 1e-4)
41
+ - Learning rate scheduling (ReduceLROnPlateau)
42
+ - Best model selection based on validation accuracy
43
+
44
+ **Hyperparameters:**
45
+ - Learning Rate: 5e-05
46
+ - Epochs: 5
47
+ - Batch Size: 32
48
+ - Weight Decay: 0.0001
49
+ - Dropout: 0.3
50
+ - Optimizer: Adam
51
+
52
+ ## Artifacts
53
+
54
+ - `resnet18_beans.pth` - PyTorch model weights
55
+ - `per_class_metrics.csv` - Detailed per-class metrics
56
+ - `confusion_matrix.png` - Confusion matrix visualization
57
+
58
+ ## Usage
59
+
60
+ Download and load the model:
61
+
62
+ from huggingface_hub import hf_hub_download
63
+ import torch
64
+ from torchvision import models
65
+ from torch import nn
66
+
67
+ model_path = hf_hub_download(
68
+ repo_id="vGiacomov/image-classifier-beans",
69
+ filename="resnet18_beans.pth"
70
+ )
71
+
72
+ model = models.resnet18()
73
+ model.fc = nn.Sequential(
74
+ nn.Dropout(0.3),
75
+ nn.Linear(model.fc.in_features, 3)
76
+ )
77
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
78
+ model.eval()
79
+
80
+ ## License
81
 
82
+ Apache 2.0
 
 
confusion_matrix.png CHANGED
per_class_metrics.csv CHANGED
@@ -1,7 +1,7 @@
1
  ,precision,recall,f1-score,support
2
- 0,0.9772727272727273,0.9772727272727273,0.9772727272727273,44.0
3
- 1,0.9782608695652174,1.0,0.989010989010989,45.0
4
- 2,0.9767441860465116,0.9545454545454546,0.9655172413793104,44.0
5
- accuracy,0.9774436090225563,0.9774436090225563,0.9774436090225563,0.9774436090225563
6
- macro avg,0.9774259276281522,0.9772727272727272,0.9772669858876756,133.0
7
- weighted avg,0.9774322053870774,0.9774436090225563,0.9773552866630388,133.0
 
1
  ,precision,recall,f1-score,support
2
+ Healthy,1.0,0.7954545454545454,0.8860759493670886,44.0
3
+ Bean Rust,0.8035714285714286,1.0,0.8910891089108911,45.0
4
+ Angular Leaf Spot,1.0,0.9545454545454546,0.9767441860465116,44.0
5
+ accuracy,0.9172932330827067,0.9172932330827067,0.9172932330827067,0.9172932330827067
6
+ macro avg,0.9345238095238096,0.9166666666666666,0.9179697481081638,133.0
7
+ weighted avg,0.9335392051557464,0.9172932330827067,0.9177676380390113,133.0
resnet18_beans.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bc36a065da9c253480a10f50676af3e8e7e37b19a1c7b4ae39459542f290acc0
3
- size 44792331
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2ba2b958dcf443e818257191c69f9cc804dc0df26934064e371c0f51fb4b5d2
3
+ size 44792395