# SaffronVerify – ConvNeXt-Base (Pretrained)
A high-accuracy saffron quality classification model trained on the Arko007/saffron-verify dataset. The model classifies saffron images into three classes: Mogra, Lacha, and Adulterated.
## Model Performance
The best checkpoint was saved at Epoch 13; early stopping triggered at Epoch 20.
| Metric | Value |
|---|---|
| Macro F1 | 0.9888 |
| Accuracy | 98.96% |
| Val Loss | 0.3562 |
### Per-Class Results (Epoch 13, Best Checkpoint)
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| mogra | 0.98 | 0.98 | 0.98 | 56 |
| lacha | 0.98 | 0.98 | 0.98 | 64 |
| adulterated | 1.00 | 1.00 | 1.00 | 72 |
| macro avg | 0.99 | 0.99 | 0.99 | 192 |
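Macro F1 here is the unweighted mean of the three per-class F1 scores (accuracy, by contrast, weights classes by support). A quick sketch with the rounded values from the table, assuming the standard F1 definition:

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# Rounded per-class scores from the table (mogra, lacha, adulterated)
per_class_f1 = [f1(0.98, 0.98), f1(0.98, 0.98), f1(1.00, 1.00)]

# Unweighted mean over classes; the reported 0.9888 comes from unrounded values
macro_f1 = sum(per_class_f1) / len(per_class_f1)
```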
## Training Details
| Parameter | Value |
|---|---|
| Base Model | convnext_base (ImageNet-21k pretrained via timm) |
| Image Size | 512 × 512 |
| Effective Batch Size | 96 (16 per GPU × 2 GPUs × 3 grad accum) |
| Optimizer | AdamW (β₁=0.9, β₂=0.999) |
| Learning Rate | 5e-6 (backbone) / 2.5e-5 (head) |
| Scheduler | Warmup (5 epochs) + Cosine Annealing |
| Regularization | Drop rate 0.3, Drop path 0.2, Label smoothing 0.1 |
| Augmentation | Mixup (α=0.4) + CutMix (α=1.0) |
| AMP | float16 |
| Hardware | 2× NVIDIA Tesla T4 (DDP) |
| Best Epoch | 13 / 50 |
| Early Stopping | Patience 7 (triggered at Epoch 20) |
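The scheduler row above (5-epoch warmup followed by cosine annealing) can be sketched as a pure function of the epoch. The linear warmup shape and a minimum LR of 0 are assumptions; the card does not spell them out:

```python
import math

def lr_at_epoch(epoch, base_lr=5e-6, warmup_epochs=5, total_epochs=50, min_lr=0.0):
    """Assumed schedule: linear warmup to base_lr, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Backbone LR peaks at 5e-6 after warmup; the head uses 2.5e-5 the same way
peak = lr_at_epoch(4)
```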
## Training Progression
| Epoch | Val Loss | Accuracy | Macro F1 |
|---|---|---|---|
| 1 | 1.0631 | 51.04% | 0.5088 |
| 2 | 0.9541 | 71.88% | 0.7154 |
| 3 | 0.8096 | 81.77% | 0.8118 |
| 5 | 0.5122 | 90.62% | 0.9033 |
| 7 | 0.4153 | 95.31% | 0.9506 |
| 10 | 0.3676 | 97.92% | 0.9777 |
| 13 | 0.3562 | 98.96% | 0.9888 |
| 20 | – | – | – (early stop) |
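The progression is consistent with patience-7 early stopping: no improvement over the Epoch-13 best validation loss for 7 epochs triggers the stop at Epoch 20. A minimal sketch of that rule (the exact implementation is assumed):

```python
def early_stop_epoch(val_losses, patience=7):
    """Return (best_epoch, stop_epoch) for a min-val-loss early-stopping rule."""
    best, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch  # patience exhausted: stop here
    return best_epoch, None  # ran to completion without stopping

# Synthetic loss curve: improves through epoch 13, then plateaus
losses = [1.0 - 0.05 * i for i in range(13)] + [0.5] * 10
best_epoch, stop_epoch = early_stop_epoch(losses)
```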
## Offline Data Augmentation
The training set was augmented offline from 167 real images to 3,840 balanced images (1,280 per class) using a heavy Albumentations pipeline: random crops, flips, rotations, colour jitter, blur, noise, elastic transforms, perspective distortion, CoarseDropout, and CLAHE. The validation set was augmented from 41 real images to 192 balanced images (64 per class).
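The card does not say how the raw images were apportioned to reach exactly 1,280 per class; one simple scheme is to give each raw image an (almost) equal number of augmented copies, as in this sketch:

```python
import random

def augmentation_plan(class_counts, target_per_class, seed=0):
    """For each class, decide how many augmented copies each raw image gets
    so that the class totals exactly target_per_class images."""
    rng = random.Random(seed)
    plan = {}
    for cls, n_raw in class_counts.items():
        base, extra = divmod(target_per_class, n_raw)
        copies = [base] * n_raw
        # Distribute the remainder over randomly chosen raw images
        for i in rng.sample(range(n_raw), extra):
            copies[i] += 1
        plan[cls] = copies
    return plan

# Raw training counts from the Dataset section below
plan = augmentation_plan({"mogra": 64, "lacha": 64, "adulterated": 39}, 1280)
```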
## Usage
```python
import torch
import timm
import torch.nn as nn
from torchvision import transforms
from PIL import Image


class SaffronVerifyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature extractor: ConvNeXt-Base without its classifier head
        self.backbone = timm.create_model(
            "convnext_base", pretrained=False,
            num_classes=0, drop_rate=0.3, drop_path_rate=0.2,
        )
        feat_dim = self.backbone.num_features
        # Custom 3-class classification head
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Dropout(p=0.3),
            nn.Linear(feat_dim, 512),
            nn.GELU(),
            nn.Dropout(p=0.15),
            nn.Linear(512, 3),
        )

    def forward(self, x):
        return self.head(self.backbone(x))


CLASSES = ["mogra", "lacha", "adulterated"]

# Load model
model = SaffronVerifyModel()
ckpt = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state"])
model.eval()

# Preprocess with the same 512x512 crop and ImageNet statistics used in training
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

img = Image.open("saffron.jpg").convert("RGB")
tensor = transform(img).unsqueeze(0)

# Predict
with torch.no_grad():
    logits = model(tensor)
pred = logits.argmax(1).item()
print(f"Predicted class: {CLASSES[pred]}")
```
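The snippet above reports only the argmax class. If you also want per-class probabilities, apply a softmax to the logits; a sketch in plain Python with hypothetical logit values:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for ["mogra", "lacha", "adulterated"]
probs = softmax([2.1, 0.3, -1.5])
```

With real model output, `logits[0].tolist()` gives the list to pass in; `logits.softmax(1)` does the same directly on the tensor.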
## Dataset
- Source: Arko007/saffron-verify
- Raw train: 64 mogra + 64 lacha + 39 adulterated = 167 images
- Raw val: 16 mogra + 16 lacha + 9 adulterated = 41 images
- Augmented train: 3840 (balanced, 1280/class)
- Augmented val: 192 (balanced, 64/class)
## License
Apache 2.0