---
license: mit
tags:
- image-classification
- pytorch
- convnext
- birds
- computer-vision
datasets:
- CUB-200-2011
metrics:
- accuracy
library_name: pytorch
---

# Bird Species Classification - ConvNeXt-Base

## Model Description

This model classifies 200 bird species using the ConvNeXt-Base architecture with transfer learning.

## Performance

- **Test Accuracy**: 83.64%
- **Average Per-Class Accuracy**: 83.29%
- **Architecture**: ConvNeXt-Base (87M parameters)
- **Dataset**: CUB-200-2011 (200 bird species)
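
The card reports two accuracy figures. Test accuracy weights every test image equally, while average per-class accuracy (assumed here to be the usual unweighted mean over the 200 classes) gives every species equal weight. The two are equal whenever every class has the same number of test images; otherwise larger classes pull the overall figure toward their own accuracy. The sketch below shows both computations; `accuracy_summary` is a hypothetical helper, not part of this repository.

```python
from collections import Counter


def accuracy_summary(y_true, y_pred):
    """Return (overall accuracy, average per-class accuracy).

    Overall accuracy weights every sample equally; average per-class
    accuracy is the unweighted mean of each class's own accuracy.
    """
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")

    totals = Counter(y_true)  # number of test images per class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct per class

    overall = sum(hits.values()) / len(y_true)
    per_class = sum(hits[c] / totals[c] for c in totals) / len(totals)
    return overall, per_class


# Toy example: class 0 has three images (all correct), class 1 has one (wrong).
overall, per_class = accuracy_summary([0, 0, 0, 1], [0, 0, 0, 0])
print(f"overall={overall:.2f}, per-class={per_class:.2f}")  # overall=0.75, per-class=0.50
```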

## Training Details

### Model Architecture

- **Base Model**: ConvNeXt-Base pretrained on ImageNet-1K
- **Classifier**: Custom 2-layer head (Linear 1024→1024, ReLU, Linear 1024→200) with dropout, replacing the final linear layer of the original classifier
- **Input Size**: 224×224 RGB images

### Training Strategy

- **Phase 1** (40 epochs): frozen backbone; only the classifier is trained
  - Learning rate: 0.001
  - Batch size: 32
- **Phase 2** (20 epochs): full fine-tuning
  - Learning rate: 0.0001
  - Batch size: 32
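
The two-phase schedule above can be sketched as follows. This is a minimal sketch, not the author's training script: the epochs, learning rates, batch size and weight decay come from this card, but the optimizer (AdamW) is an assumption, and `PHASES` and `configure_phase` are hypothetical names. It relies on torchvision's ConvNeXt layout, where `model.features` is the backbone and `model.classifier` holds the head.

```python
import torch
import torch.nn as nn

# Values from the model card; AdamW is an ASSUMPTION (the optimizer is not stated).
PHASES = {
    1: {"epochs": 40, "lr": 1e-3, "train_backbone": False},
    2: {"epochs": 20, "lr": 1e-4, "train_backbone": True},
}
BATCH_SIZE = 32
WEIGHT_DECAY = 0.005


def configure_phase(model: nn.Module, phase: int) -> torch.optim.Optimizer:
    """Freeze or unfreeze the backbone and return an optimizer for `phase`.

    Expects a torchvision-style ConvNeXt, where `model.features` is the
    backbone and `model.classifier` holds the (replaced) head.
    """
    cfg = PHASES[phase]
    for p in model.features.parameters():
        p.requires_grad = cfg["train_backbone"]
    for p in model.classifier.parameters():
        p.requires_grad = True  # the head is trained in both phases

    # Only hand trainable parameters to the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=cfg["lr"], weight_decay=WEIGHT_DECAY)
```

Each phase would then run `PHASES[phase]["epochs"]` epochs over a `DataLoader` with `batch_size=BATCH_SIZE`; the loop itself is omitted. Note that torchvision places a `LayerNorm2d` at the start of `model.classifier`, so this sketch also trains that layer in phase 1; whether the original run did so is not stated.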

### Regularization

- Dropout: 0.6 before the first classifier layer, 0.5 before the output layer
- Label smoothing: 0.2
- Weight decay: 0.005
- Data augmentation: rotation, flip, color jitter, random erasing

## Usage

```python
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Build the architecture: ConvNeXt-Base with the custom 200-class head
model = models.convnext_base(weights=None)
num_features = model.classifier[2].in_features  # 1024 for ConvNeXt-Base
model.classifier[2] = nn.Sequential(
    nn.Dropout(0.6),
    nn.Linear(num_features, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 200),
)

# Load the fine-tuned weights (a state_dict) onto the CPU
state_dict = torch.load('final_model.pth', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # disables dropout for inference

# Preprocessing: resize and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Predict
image = Image.open('bird.jpg').convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add a batch dimension

with torch.no_grad():
    outputs = model(image_tensor)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    top5_prob, top5_indices = torch.topk(probabilities, 5)

print("Top 5 Predictions:")
for prob, idx in zip(top5_prob, top5_indices):
    print(f"Class {idx.item()}: {prob.item() * 100:.2f}%")
```
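
The snippet above prints numeric class indices. To turn them into species names, one option is to read the `classes.txt` file shipped with CUB-200-2011, whose lines look like `1 001.Black_footed_Albatross`. This is a sketch under an explicit assumption: that model output index `i` corresponds to CUB class id `i + 1` (which would hold if training used torchvision's `ImageFolder` on the CUB `images/` folders, since those sort as `001` to `200`). Verify this against your own training pipeline. `parse_cub_classes` is a hypothetical helper.

```python
def parse_cub_classes(text: str) -> list[str]:
    """Map model output indices to species names from CUB's classes.txt.

    Each line has the form "<id> <NNN>.<Species_Name>", for example
    "1 001.Black_footed_Albatross". ASSUMPTION: output index i is class id i + 1.
    """
    entries = []
    for line in text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        class_id, folder = line.split(maxsplit=1)
        # "001.Black_footed_Albatross" -> "Black footed Albatross"
        species = folder.split(".", 1)[1].replace("_", " ")
        entries.append((int(class_id), species))
    entries.sort()  # order by numeric class id
    return [species for _, species in entries]
```

With `class_names = parse_cub_classes(open("CUB_200_2011/classes.txt").read())` (an assumed path to the extracted dataset), the print loop in the Usage snippet can show `class_names[idx.item()]` instead of the raw index.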

## Try it out!

Try the live demo: [Bird Species Classifier](https://huggingface.co/spaces/AshProg/AppliedMachineLearning_BirdClassifierInterface)

## Model Files

- `final_model.pth` (1.06 GB): Full model weights

## Citation

Dataset: CUB-200-2011

```bibtex
@techreport{WahCUB_200_2011,
  Title = {{The Caltech-UCSD Birds-200-2011 Dataset}},
  Author = {Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S.},
  Year = {2011},
  Institution = {California Institute of Technology},
  Number = {CNS-TR-2011-001}
}
```

## Contact

For questions or issues, please open an issue on the Space repository.