---
license: apache-2.0
language:
- en
tags:
- image-classification
- walnut
- defect-detection
- efficientnet
- timm
- pytorch
- surface-defect
- quality-control
pipeline_tag: image-classification
library_name: timm
base_model: timm/efficientnet_b3.ra2_in1k
metrics:
- accuracy
- f1
model-index:
- name: Walnut Shell Defect Classifier
  results:
  - task:
      type: image-classification
      name: Image Classification
    dataset:
      name: Nut Surface Defect Dataset (nutsv2ifolder_split)
      type: weihaoreal/nut-surface-defect-dataset
    metrics:
    - type: accuracy
      value: 0.9855
      name: Validation Accuracy
    - type: f1
      value: 0.98
      name: Macro F1
---

# Walnut Shell Defect Classifier

EfficientNet-B3 fine-tuned to classify walnut shell defects into 4 categories. It was trained on the Nut Surface Defect Dataset, with the dataset's classes remapped to a walnut-specific defect taxonomy.

## Classes

| Output Label | Remapped From (Dataset) |
|---|---|
| Healthy | Excellent |
| Black Spot | Rusting |
| Shriveled | Scratches |
| Damaged | Deformation + Fracture |

## Metrics (Epoch 8, Best Checkpoint)

| Class | Precision | Recall | F1 |
|---|---|---|---|
| Healthy | 0.88 | 1.00 | 0.93 |
| Black Spot | 1.00 | 0.99 | 1.00 |
| Shriveled | 1.00 | 0.98 | 0.99 |
| Damaged | 1.00 | 0.98 | 0.99 |
| **Macro Avg** | **0.97** | **0.99** | **0.98** |
| **Weighted Avg** | **0.99** | **0.99** | **0.99** |

**Validation Accuracy: 98.55% | Macro F1: 0.98**

## Training Setup

| Parameter | Value |
|---|---|
| Base Model | EfficientNet-B3, ImageNet-pretrained (`timm/efficientnet_b3.ra2_in1k`) |
| Image Size | 512×512 px |
| Batch Size | 18 per GPU × 2 T4 = 36 effective |
| Optimizer | AdamW (lr = 2e-5, weight decay = 1e-2) |
| Scheduler | Cosine annealing + 3-epoch warmup |
| Precision | FP16 mixed precision (`torch.cuda.amp`) |
| Dropout (`drop_rate`) | 0.4 |
| Label Smoothing | 0.05 |
| Early-Stopping Patience | 7 epochs |
| Hardware | Kaggle, 2× NVIDIA T4 (16 GB each) |

## Inference

The example below loads the `best_model.pth` checkpoint and classifies a single image.

```python
import torch
import timm
from PIL import Image
import torchvision.transforms as transforms

CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]

# Same architecture as in training; ImageNet weights are not needed because the checkpoint replaces them.
model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=4, drop_rate=0.4)

# On PyTorch >= 2.6, torch.load defaults to weights_only=True. If loading fails and
# you trust the file, pass weights_only=False.
ckpt = torch.load("best_model.pth", map_location="cpu")
# Strip the "module." prefix that multi-GPU wrappers (DataParallel/DDP) add to parameter names.
state = {k.removeprefix("module."): v for k, v in ckpt["model_state_dict"].items()}
model.load_state_dict(state)
model.eval()

# 512x512 input as in training, with ImageNet mean/std normalization.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("walnut.jpg").convert("RGB")
x = transform(img).unsqueeze(0)  # shape: [1, 3, 512, 512]

with torch.inference_mode():
    probs = torch.softmax(model(x), dim=1)  # shape: [1, 4]

# Reduce over the class dimension (dim=1), not the batch dimension.
conf, idx = probs.max(dim=1)
print({"defect_class": CLASSES[idx.item()], "confidence": round(conf.item(), 4)})
```

## License

Apache 2.0
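## Class Remapping Sketch

The Classes table merges five dataset categories into four output labels. The snippet below expresses that mapping as a lookup. This is a minimal sketch under two assumptions. First, the source folder names (`Excellent`, `Rusting`, `Scratches`, `Deformation`, `Fracture`) are taken from the table and may not exactly match the directory names in `nutsv2ifolder_split`. Second, `remap_label` is a hypothetical helper, not part of the released training code. The output index order follows the `CLASSES` list in the Inference example.

```python
from collections import Counter

# Assumed dataset folder names, taken from the "Remapped From" column of the Classes table.
SOURCE_TO_LABEL = {
    "Excellent": "Healthy",
    "Rusting": "Black Spot",
    "Scratches": "Shriveled",
    "Deformation": "Damaged",  # Deformation and Fracture merge into one output class
    "Fracture": "Damaged",
}

# Output index order, matching CLASSES in the Inference example.
CLASSES = ["Healthy", "Black Spot", "Shriveled", "Damaged"]
LABEL_TO_INDEX = {name: i for i, name in enumerate(CLASSES)}


def remap_label(source_class: str) -> int:
    """Hypothetical helper: map a dataset folder name to the model's output index."""
    return LABEL_TO_INDEX[SOURCE_TO_LABEL[source_class]]


if __name__ == "__main__":
    # Toy list of source labels (illustrative only, not real dataset counts).
    source_labels = ["Excellent", "Rusting", "Deformation", "Fracture", "Scratches"]
    counts = Counter(CLASSES[remap_label(s)] for s in source_labels)
    print(counts)
```

Keeping the many-to-one mapping in a single dictionary makes the merge of Deformation and Fracture explicit. It also means the per-class counts after remapping can be checked before training.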
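## Macro Average Check

The **Macro Avg** row in the Metrics table is the unweighted mean of the four per-class rows, so every class counts equally regardless of how many validation images it has. The snippet below recomputes the macro average from the rounded per-class values. The results (0.97, 0.9875, 0.9775) match the reported 0.97 / 0.99 / 0.98 after rounding to two decimals. The **Weighted Avg** row instead weights each class by its support, which the card does not report, so it is not reproduced here.

```python
from statistics import mean

# Rounded per-class values copied from the Metrics table,
# in the order Healthy, Black Spot, Shriveled, Damaged.
PER_CLASS = {
    "precision": [0.88, 1.00, 1.00, 1.00],
    "recall": [1.00, 0.99, 0.98, 0.98],
    "f1": [0.93, 1.00, 0.99, 0.99],
}

# Macro average: a plain mean over classes that ignores class support.
macro = {metric: mean(values) for metric, values in PER_CLASS.items()}

if __name__ == "__main__":
    for metric, value in macro.items():
        print(f"macro {metric}: {value:.4f}")
```

Since the other three classes all have precision 1.00, Healthy's 0.88 is the only thing pulling the macro precision down to 0.97. The weighted precision is higher (0.99), which is consistent with Healthy making up less than a quarter of the validation images.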
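## Label Smoothing Targets

With label smoothing ε = 0.05 over K = 4 classes, the one-hot training target is mixed with a uniform distribution. The true class receives 1 − ε + ε/K = 0.9625, and every other class receives ε/K = 0.0125. This is the formulation used by PyTorch's `nn.CrossEntropyLoss(label_smoothing=...)`. The card does not say which loss implementation was used, so treat the snippet below as an illustration of the arithmetic rather than the training code. `smoothed_target` is a hypothetical helper.

```python
def smoothed_target(true_index: int, num_classes: int = 4, epsilon: float = 0.05) -> list[float]:
    """Mix a one-hot target with a uniform distribution (standard label smoothing)."""
    uniform_share = epsilon / num_classes
    target = [uniform_share] * num_classes
    # The true class keeps (1 - epsilon) of the mass plus its own uniform share.
    target[true_index] += 1.0 - epsilon
    return target


if __name__ == "__main__":
    # Smoothed target for a "Black Spot" sample (index 1 in CLASSES).
    print(smoothed_target(1))
```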
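## Learning-Rate Schedule Sketch

The Training Setup table lists AdamW at lr = 2e-5 with cosine annealing and a 3-epoch warmup. It does not give the warmup shape, the total epoch budget, or whether the schedule steps per epoch or per batch. The sketch below fills those gaps with assumptions:

- linear warmup;
- one step per epoch;
- cosine decay to zero;
- a hypothetical `TOTAL_EPOCHS = 30`.

It returns a multiplier on the base learning rate, which is the form `torch.optim.lr_scheduler.LambdaLR` expects for its `lr_lambda` argument.

```python
import math

BASE_LR = 2e-5       # from the Training Setup table
WARMUP_EPOCHS = 3    # from the Training Setup table
TOTAL_EPOCHS = 30    # hypothetical: the card does not state the epoch budget


def lr_multiplier(epoch: int) -> float:
    """Scale factor applied to BASE_LR at a given 0-indexed epoch.

    Assumes linear warmup over WARMUP_EPOCHS, then cosine decay toward zero
    over the remaining epochs.
    """
    if epoch < WARMUP_EPOCHS:
        # Linear ramp: 1/3, 2/3, 3/3 for epochs 0, 1, 2.
        return (epoch + 1) / WARMUP_EPOCHS
    # Fraction of the cosine phase completed, clamped to [0, 1].
    progress = (epoch - WARMUP_EPOCHS) / max(1, TOTAL_EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


if __name__ == "__main__":
    for epoch in range(TOTAL_EPOCHS):
        print(f"epoch {epoch:2d}: lr = {BASE_LR * lr_multiplier(epoch):.2e}")
```

In PyTorch, one possible wiring is `optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-2)`, `scheduler = LambdaLR(optimizer, lr_lambda=lr_multiplier)`, and one `scheduler.step()` per epoch. That wiring is also an assumption, not the documented training loop.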