---
license: apache-2.0
language:
- en
tags:
- image-classification
- walnut
- defect-detection
- efficientnet
- timm
- pytorch
- surface-defect
- quality-control
pipeline_tag: image-classification
library_name: timm
base_model: timm/efficientnet_b3.ra2_in1k
metrics:
- accuracy
- f1
model-index:
- name: Walnut Shell Defect Classifier
  results:
  - task:
      type: image-classification
      name: Image Classification
    dataset:
      name: Nut Surface Defect Dataset (nutsv2ifolder_split)
      type: weihaoreal/nut-surface-defect-dataset
    metrics:
    - type: accuracy
      value: 0.9855
      name: Validation Accuracy
    - type: f1
      value: 0.98
      name: Macro F1
---
# Walnut Shell Defect Classifier
EfficientNet-B3 fine-tuned to classify walnut shell images into four defect categories.
It was trained on the Nut Surface Defect Dataset, with the dataset's original classes remapped to a walnut-specific defect taxonomy.
## Classes
| Output Label | Remapped From (Dataset) |
|---|---|
| Healthy | Excellent |
| Black Spot | Rusting |
| Shriveled | Scratches |
| Damaged | Deformation + Fracture |
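The remapping in this table can be expressed as a simple lookup from dataset class name to output label. A minimal sketch, assuming the dataset's class folders are named exactly as listed (`Excellent`, `Rusting`, `Scratches`, `Deformation`, `Fracture`); the `REMAP` dictionary and the `remap_label` helper are hypothetical illustrations, not taken from the training code.

```python
# Hypothetical mapping from the dataset's class names to this model's labels,
# following the Classes table. Folder names are assumed to match the table.
REMAP = {
    "Excellent": "Healthy",
    "Rusting": "Black Spot",
    "Scratches": "Shriveled",
    "Deformation": "Damaged",
    "Fracture": "Damaged",  # two source classes merge into one output label
}

# Output label order, matching the CLASSES list in the inference snippet.
CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]


def remap_label(dataset_class: str) -> int:
    """Return the output class index for a dataset class name (hypothetical helper)."""
    return CLASSES.index(REMAP[dataset_class])


if __name__ == "__main__":
    for name in REMAP:
        print(f"{name:12s} -> {REMAP[name]} (index {remap_label(name)})")
```

Because `Deformation` and `Fracture` both map to `Damaged`, that output class collects samples from two source classes.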
## Validation Metrics (Epoch 8, Best Checkpoint)
| Class | Precision | Recall | F1 |
|---|---|---|---|
| Healthy | 0.88 | 1.00 | 0.93 |
| Black Spot | 1.00 | 0.99 | 1.00 |
| Shriveled | 1.00 | 0.98 | 0.99 |
| Damaged | 1.00 | 0.98 | 0.99 |
| **Macro Avg** | **0.97** | **0.99** | **0.98** |
| **Weighted Avg** | **0.99** | **0.99** | **0.99** |

**Validation Accuracy: 98.55% | Macro F1: 0.98**
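Macro averages weight every class equally. As a result, the lower Healthy precision (0.88) pulls macro precision down to 0.97, even though the other three classes score 1.00: (0.88 + 1.00 + 1.00 + 1.00) / 4 = 0.97. The plain-Python sketch below shows how per-class precision, recall, and F1 and their macro averages are defined from label lists. It uses toy labels and is an illustration of the standard definitions, not the evaluation script used for this card. The weighted averages cannot be reproduced here because the per-class sample counts are not listed.

```python
from typing import List, Tuple


def per_class_prf(y_true: List[int], y_pred: List[int],
                  num_classes: int) -> List[Tuple[float, float, float]]:
    """Precision, recall and F1 for each class index (0.0 when undefined)."""
    results = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results.append((precision, recall, f1))
    return results


def macro_average(rows: List[Tuple[float, float, float]]) -> Tuple[float, ...]:
    """Unweighted mean of precision, recall and F1 across classes."""
    n = len(rows)
    return tuple(sum(r[i] for r in rows) / n for i in range(3))


if __name__ == "__main__":
    # Toy labels: 4 classes, 2 samples each; one class-3 sample is predicted as class 0.
    y_true = [0, 0, 1, 1, 2, 2, 3, 3]
    y_pred = [0, 0, 1, 1, 2, 2, 3, 0]
    rows = per_class_prf(y_true, y_pred, num_classes=4)
    print(rows)
    print(macro_average(rows))
```

Macro F1 here is the mean of the per-class F1 scores, which is the convention used by common tools such as scikit-learn's `average="macro"`.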
## Training Setup
| Parameter | Value |
|---|---|
| Base Model | EfficientNet-B3, ImageNet-1k pretrained (`timm/efficientnet_b3.ra2_in1k`) |
| Image Size | 512×512 px |
| Batch Size | 18 per GPU × 2 GPUs = 36 effective |
| Optimizer | AdamW (lr = 2e-5, weight decay = 1e-2) |
| Scheduler | Cosine annealing with 3-epoch warmup |
| Precision | FP16 mixed precision (`torch.cuda.amp`) |
| Dropout (`drop_rate`) | 0.4 |
| Label Smoothing | 0.05 |
| Early Stopping Patience | 7 epochs |
| Hardware | Kaggle, 2× NVIDIA T4 (16 GB each) |
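The learning-rate schedule and the label-smoothed targets from this table can be sketched in plain Python. This is an illustration under stated assumptions: the total epoch budget (`TOTAL_EPOCHS = 30` below) is not given in this card and is hypothetical, the warmup is assumed to be linear, and the cosine phase is assumed to decay to zero. The smoothing follows the convention of PyTorch's `nn.CrossEntropyLoss(label_smoothing=...)`, which mixes the one-hot target with a uniform distribution over all classes.

```python
import math
from typing import List

BASE_LR = 2e-5          # AdamW learning rate from the table
WARMUP_EPOCHS = 3       # from the table
TOTAL_EPOCHS = 30       # hypothetical: the total epoch budget is not stated in this card
LABEL_SMOOTHING = 0.05  # from the table
NUM_CLASSES = 4


def lr_factor(epoch: int) -> float:
    """Multiplier on BASE_LR for a 0-indexed epoch: linear warmup, then cosine decay to 0."""
    if epoch < WARMUP_EPOCHS:
        return (epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / max(1, TOTAL_EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


def smoothed_target(true_class: int, eps: float = LABEL_SMOOTHING,
                    num_classes: int = NUM_CLASSES) -> List[float]:
    """Label-smoothed target: (1 - eps) * one_hot + eps / num_classes."""
    off = eps / num_classes
    return [(1.0 - eps) + off if c == true_class else off
            for c in range(num_classes)]


if __name__ == "__main__":
    for epoch in range(TOTAL_EPOCHS):
        print(f"epoch {epoch:2d}: lr = {BASE_LR * lr_factor(epoch):.2e}")
    # True class gets ~0.9625, every other class ~0.0125.
    print(smoothed_target(0))
```

`lr_factor` takes an epoch index and returns a multiplier, which is the shape `torch.optim.lr_scheduler.LambdaLR` expects for its `lr_lambda` argument when the scheduler is stepped once per epoch.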
## Inference
```python
import torch
import timm
import torchvision.transforms as transforms
from PIL import Image

CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]

# Rebuild the training architecture; the weights come from the checkpoint.
model = timm.create_model("efficientnet_b3", pretrained=False,
                          num_classes=4, drop_rate=0.4)
ckpt = torch.load("best_model.pth", map_location="cpu")
# Strip the leading "module." added by multi-GPU wrappers (Python 3.9+ for removeprefix).
state = {k.removeprefix("module."): v for k, v in ckpt["model_state_dict"].items()}
model.load_state_dict(state)
model.eval()

# Resize to the 512x512 training resolution and apply ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("walnut.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # shape: (1, 3, 512, 512)

with torch.no_grad():
    probs = torch.softmax(model(x), dim=1)  # shape: (1, 4)
conf, idx = probs[0].max(dim=0)  # highest-probability class for the single image
print({"defect_class": CLASSES[idx.item()], "confidence": round(conf.item(), 4)})
```
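Training on two GPUs suggests the checkpoint keys carry the `module.` prefix that PyTorch's `nn.DataParallel` and `DistributedDataParallel` wrappers add to every parameter name (an assumption based on the 2× T4 setup and the prefix handling in the snippet). A plain-Python sketch of prefix-only stripping follows; unlike a global `str.replace`, it touches only the leading prefix and leaves keys from an unwrapped checkpoint unchanged. The function name `strip_module_prefix` and the sample keys are hypothetical.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_module_prefix(state: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Remove a leading wrapper prefix from state-dict keys; other keys pass through."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }


if __name__ == "__main__":
    # Illustrative keys; in a real checkpoint the values would be torch tensors.
    saved = {"module.conv_stem.weight": 1, "module.classifier.bias": 2}
    print(strip_module_prefix(saved))
    # {'conv_stem.weight': 1, 'classifier.bias': 2}
```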
## License
Apache 2.0