---
license: apache-2.0
language:
- en
tags:
- image-classification
- walnut
- defect-detection
- efficientnet
- timm
- pytorch
- surface-defect
- quality-control
pipeline_tag: image-classification
library_name: timm
base_model: timm/efficientnet_b3.ra2_in1k
metrics:
- accuracy
- f1
model-index:
- name: Walnut Shell Defect Classifier
  results:
  - task:
      type: image-classification
      name: Image Classification
    dataset:
      name: Nut Surface Defect Dataset (nutsv2ifolder_split)
      type: weihaoreal/nut-surface-defect-dataset
    metrics:
    - type: accuracy
      value: 0.9855
      name: Validation Accuracy
    - type: f1
      value: 0.98
      name: Macro F1
---

# Walnut Shell Defect Classifier

EfficientNet-B3 fine-tuned to classify walnut shell defects into 4 categories.
Trained on the Nut Surface Defect Dataset, with the dataset's classes remapped to a walnut-specific defect taxonomy.

## Classes

| Output Label | Remapped From (Dataset) |
|---|---|
| Healthy | Excellent |
| Black Spot | Rusting |
| Shriveled | Scratches |
| Damaged | Deformation + Fracture |
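The remapping in the table can be expressed as a small lookup. This is a minimal sketch that assumes the dataset's class names are spelled as in the table; the actual folder names in `nutsv2ifolder_split` may differ in case or wording. `DATASET_TO_WALNUT`, `CLASS_TO_IDX` and `remap_label` are hypothetical names, and the index order is taken from the `CLASSES` list in the Inference snippet.

```python
# Hypothetical mapping from dataset class names (as listed in this card)
# to the walnut-specific output labels. Folder names on disk may differ.
DATASET_TO_WALNUT = {
    "Excellent": "Healthy",
    "Rusting": "Black Spot",
    "Scratches": "Shriveled",
    "Deformation": "Damaged",
    "Fracture": "Damaged",  # two source classes merge into one output label
}

# Output index order, matching the CLASSES list used for inference.
CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]
CLASS_TO_IDX = {name: i for i, name in enumerate(CLASSES)}


def remap_label(dataset_class: str) -> int:
    """Return the 4-way target index for a dataset class name."""
    return CLASS_TO_IDX[DATASET_TO_WALNUT[dataset_class]]
```

For example, both `remap_label("Deformation")` and `remap_label("Fracture")` map to the `Damaged` index, which is why the model has 4 outputs rather than 5.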

## Metrics (Epoch 8 — Best Checkpoint)

| Class | Precision | Recall | F1 |
|---|---|---|---|
| Healthy | 0.88 | 1.00 | 0.93 |
| Black Spot | 1.00 | 0.99 | 1.00 |
| Shriveled | 1.00 | 0.98 | 0.99 |
| Damaged | 1.00 | 0.98 | 0.99 |
| **Macro Avg** | **0.97** | **0.99** | **0.98** |
| **Weighted Avg** | **0.99** | **0.99** | **0.99** |

**Val Accuracy: 98.55% | Macro F1: 0.98**
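The macro averages can be sanity-checked from the per-class rows: a macro average is the unweighted mean over the four classes, so each class counts equally regardless of how many validation images it has. The sketch below uses the rounded table values, so it only approximates the unrounded figures the report was computed from; `per_class` and `macro_average` are illustrative names.

```python
# Per-class values copied from the table above (already rounded to 2 decimals).
per_class = {
    "Healthy":    {"precision": 0.88, "recall": 1.00, "f1": 0.93},
    "Black Spot": {"precision": 1.00, "recall": 0.99, "f1": 1.00},
    "Shriveled":  {"precision": 1.00, "recall": 0.98, "f1": 0.99},
    "Damaged":    {"precision": 1.00, "recall": 0.98, "f1": 0.99},
}


def macro_average(rows: dict, metric: str) -> float:
    """Unweighted mean of one metric over all classes."""
    return sum(r[metric] for r in rows.values()) / len(rows)


for metric in ("precision", "recall", "f1"):
    print(metric, round(macro_average(per_class, metric), 4))
```

With these inputs the sketch gives 0.97 precision, 0.9875 recall and 0.9775 F1, which round to the reported macro row. The weighted averages cannot be reproduced this way because per-class support counts are not listed in this card, although the gap between macro (0.97) and weighted (0.99) precision suggests that Healthy, the class with the lowest precision, is a relatively small share of the validation set.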

## Training Setup

| Parameter | Value |
|---|---|
| Base Model | EfficientNet-B3, ImageNet-1k pretrained (`timm/efficientnet_b3.ra2_in1k`) |
| Image Size | 512×512 px |
| Batch Size | 18 per GPU × 2 T4 = 36 effective |
| Optimizer | AdamW (lr=2e-5, wd=1e-2) |
| Scheduler | Cosine Annealing + 3-epoch warmup |
| Precision | Mixed precision, FP16 (`torch.cuda.amp`) |
| Drop Rate | 0.4 |
| Label Smoothing | 0.05 |
| Early Stop Patience | 7 epochs |
| Hardware | Kaggle 2× NVIDIA T4 (16 GB each) |
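A minimal PyTorch sketch of how the optimizer, schedule, loss and early-stopping rows could be wired together. Several details are assumptions, not taken from the author's training script: the warmup is linear from 10% of the base learning rate, the scheduler is stepped once per epoch, `num_epochs` is a placeholder because the total epoch budget is not stated, and the metric monitored for early stopping (for example validation accuracy or macro F1) is not specified. Mixed precision and the 2-GPU wrapper are left out for brevity. `build_training_objects` and `EarlyStopping` are hypothetical helpers.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

WARMUP_EPOCHS = 3


def build_training_objects(model: nn.Module, num_epochs: int):
    """AdamW + warmup/cosine schedule + label-smoothed loss (sketch)."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-2)
    # Assumption: linear warmup from 10% of the base LR over 3 epochs.
    warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0,
                      total_iters=WARMUP_EPOCHS)
    # Cosine annealing over the remaining (placeholder) epoch budget.
    cosine = CosineAnnealingLR(optimizer, T_max=num_epochs - WARMUP_EPOCHS)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine],
                             milestones=[WARMUP_EPOCHS])
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    return optimizer, scheduler, criterion


class EarlyStopping:
    """Signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 7):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, score: float, epoch: int) -> bool:
        """Record one epoch's validation score; return True to stop training."""
        if score > self.best_score:
            self.best_score, self.best_epoch, self.bad_epochs = score, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop, `scheduler.step()` would be called once after each epoch's optimizer updates, and a checkpoint saved whenever `stopper.best_epoch` equals the current epoch; this is one way a run can end with its best checkpoint at an earlier epoch, such as epoch 8 here. The drop rate is not set in this sketch because timm takes it at model creation (`drop_rate=0.4`), as in the Inference snippet.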

## Inference

```python
import torch
import timm
import torchvision.transforms as transforms
from PIL import Image

CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]

# Rebuild the training architecture: EfficientNet-B3 with a 4-class head.
model = timm.create_model("efficientnet_b3", pretrained=False,
                          num_classes=4, drop_rate=0.4)
ckpt = torch.load("best_model.pth", map_location="cpu")
# Strip the "module." prefix added by the multi-GPU wrapper used in training.
state = {k.removeprefix("module."): v for k, v in ckpt["model_state_dict"].items()}
model.load_state_dict(state)
model.eval()

# Resize to the 512x512 training resolution and apply ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("walnut.jpg").convert("RGB")
x = transform(img).unsqueeze(0)             # shape: [1, 3, 512, 512]
with torch.inference_mode():
    probs = torch.softmax(model(x), dim=1)  # shape: [1, 4]
conf, idx = probs[0].max(dim=0)             # top class for the single image
print({"defect_class": CLASSES[idx.item()], "confidence": round(conf.item(), 4)})
```

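To score several walnuts at once, the single-image snippet can be generalized to a batch. This is a sketch with a hypothetical `predict_batch` helper: it expects the `model` and `transform` built above, stacks the preprocessed images into one `[N, 3, H, W]` tensor, and returns the `top_k` classes per image, which can help flag borderline cases for manual review.

```python
import torch
from PIL import Image

CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]


def predict_batch(model, transform, images, top_k=2):
    """Classify a list of PIL images in one forward pass (hypothetical helper)."""
    # Preprocess each image and stack into a single [N, 3, H, W] batch.
    x = torch.stack([transform(img.convert("RGB")) for img in images])
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(x), dim=1)   # [N, num_classes]
    top_p, top_i = probs.topk(top_k, dim=1)      # sorted, highest first
    results = []
    for p_row, i_row in zip(top_p.tolist(), top_i.tolist()):
        results.append([(CLASSES[i], round(p, 4)) for p, i in zip(p_row, i_row)])
    return results
```

For example, `predict_batch(model, transform, [Image.open(p) for p in paths])` returns one list of `(label, probability)` pairs per path. Large batches at 512×512 may need to be split into chunks to fit in memory.
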
## License

Apache 2.0