---
license: mit
tags:
- image-classification
- birds
- resnet
- pytorch
- wildlife
datasets:
- nabirds
- birdsnap
- inaturalist
pipeline_tag: image-classification
---
# Bird Species Classifier (ResNet-50)
Fine-tuned ResNet-50 models for classifying North American bird species from cropped bird photographs.
## Model Description
These models are ResNet-50 backbones pretrained on ImageNet V2, fine-tuned on the [NABirds](https://dl.allawnmilner.com/nabirds) dataset augmented with [Birdsnap](https://thomasberg.org/) and [iNaturalist](https://www.inaturalist.org/) data. They are designed for use in a photography processing pipeline that first detects birds with YOLO, crops them at full resolution, then classifies the crop.
### Architecture
- **Backbone**: ResNet-50 (ImageNet V2 pretrained)
- **Pooling**: Generalized Mean (GeM) pooling
- **Head**: `Sequential(Dropout(0.4), Linear(2048, num_classes))`
- **Input size**: 240×240 pixels, normalized with ImageNet mean/std
- **Preprocessing**: `ToTensor()` + `Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))`
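For reference, GeM pooling computes the average of feature activations raised to a power `p`, then takes the `p`-th root; `p = 1` recovers average pooling and large `p` approaches max pooling. A minimal sketch (the learnable `p` and its initial value of 3.0 are common conventions, not confirmed details of this checkpoint):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeM(nn.Module):
    """Generalized Mean (GeM) pooling over a (N, C, H, W) feature map."""

    def __init__(self, p: float = 3.0, eps: float = 1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)  # learnable pooling exponent
        self.eps = eps  # clamp floor so pow() stays numerically safe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.clamp(min=self.eps).pow(self.p)       # element-wise x^p
        x = F.adaptive_avg_pool2d(x, output_size=1)  # spatial mean
        return x.pow(1.0 / self.p)                   # p-th root -> (N, C, 1, 1)
```

In a ResNet-50, this module would replace `model.avgpool` so the head sees GeM-pooled 2048-d features.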
### Training Strategy
Three-stage progressive unfreezing:
| Stage | Unfrozen Layers | Purpose |
|-------|-----------------|---------|
| 1 | FC head only | Learn species mapping on frozen backbone features |
| 2 | `layer4` + FC | Adapt high-level features |
| 3 | `layer3` + `layer4` + FC | Fine-tune mid-level features |
Training was conducted via an automated, Codex-driven research loop, with a 2-hour time budget per experiment for the 98-species model and 4–10-hour budgets for the 404-species model.
## Available Checkpoints
### `subset98_combined/best.pt` — 98 Target Species
| Metric | Value |
|--------|-------|
| Top-1 Test Accuracy | **97.4%** |
| Top-1 Val Accuracy | 97.6% |
| Classes | 98 target species |
| Training Data | NABirds + Birdsnap + iNaturalist (~38K training images) |
| Total Epochs | 12 |
| Training Time | 2 hours |
| Peak Memory | 589 MB |
| File Size | ~91 MB |
Best run: `20260319_074647_c9dbe6` — stage3 cap=6 + layer2 lr=1.5e-5
### `base_combined/best.pt` — 404 Base Species
| Metric | Value |
|--------|-------|
| Top-1 Test Accuracy | **93.6%** |
| Top-1 Val Accuracy | 93.6% |
| Classes | 404 NABirds base species (sex/morph variants collapsed) |
| Training Data | NABirds + Birdsnap + iNaturalist (~166K training images) |
| Total Epochs | 20 |
| Training Time | ~9.6 hours |
| Peak Memory | 898 MB |
| Batch Size | 128 |
| File Size | ~98 MB |
Best run: `20260319_234135_b8fe6e` — bs=128 + stage lrs 3e-4/6e-5
## Usage
### With the Bird Photography Pipeline
```bash
git clone --branch MVP https://github.com/rkutyna/BirdBrained
cd BirdBrained
pip install -r requirements.txt
python download_models.py
streamlit run frontend/bird_gallery_frontend.py
```
### Standalone Inference (PyTorch)
```python
import torch
from torchvision import models, transforms
from PIL import Image
# Load checkpoint
state_dict = torch.load("subset98_combined/best.pt", map_location="cpu")
# Build model
model = models.resnet50()
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=0.4),
    torch.nn.Linear(model.fc.in_features, 98),  # or 404 for base_combined
)
model.load_state_dict(state_dict)
model.eval()

# Preprocess a cropped bird image
transform = transforms.Compose([
    transforms.Resize((240, 240)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
img = Image.open("bird_crop.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0)

with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.softmax(logits, dim=1)
    top5_probs, top5_indices = probs.topk(5)
```
Label names are provided in the repository as CSV files:
- `label_names.csv` — 98 target species
- `label_names_nabirds_base_species.csv` — 404 base species
## Training Data
| Dataset | Images | Species | Role |
|---------|--------|---------|------|
| [NABirds](https://dl.allawnmilner.com/nabirds) | ~48K | 555 specific / 404 base | Train + Val + Test |
| [Birdsnap](https://thomasberg.org/) | ~50K | ~335 matched | Train only |
| [iNaturalist](https://www.inaturalist.org/) | ~70K | up to 280 images/species | Train only |
Validation and test splits use NABirds data only (no external data leakage).
## Limitations
- Trained on North American bird species only (NABirds taxonomy).
- Expects **cropped bird images** as input — not full scene photos. Use a bird detector (e.g., YOLO) to crop first.
- The 98-species model covers only a curated subset; out-of-distribution species will be misclassified into the nearest known class.
- Performance may degrade on heavily backlit, motion-blurred, or partially occluded subjects.
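Since the models expect cropped inputs, a detector box must be turned into a classifier crop first. A minimal sketch of that step (the `crop_bird` helper and its 10% padding default are illustrative, not the repository's actual pipeline code):

```python
from PIL import Image

def crop_bird(img: Image.Image, box: tuple[int, int, int, int],
              pad: float = 0.1) -> Image.Image:
    """Crop a detected bird box with fractional padding, clamped to image bounds.

    `box` is (x1, y1, x2, y2) in pixels, as produced by a detector such as YOLO.
    Padding keeps a margin of context around the bird for the classifier.
    """
    x1, y1, x2, y2 = box
    dx = int((x2 - x1) * pad)
    dy = int((y2 - y1) * pad)
    x1 = max(0, x1 - dx)
    y1 = max(0, y1 - dy)
    x2 = min(img.width, x2 + dx)
    y2 = min(img.height, y2 + dy)
    return img.crop((x1, y1, x2, y2))
```

The resulting crop can then be fed through the `transform` from the standalone-inference snippet above.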
## Citation
If you use these models, please cite the NABirds dataset:
```bibtex
@inproceedings{van2015building,
title={Building a bird recognition app and large scale dataset with citizen scientists: The fine print in fine-grained dataset collection},
author={Van Horn, Grant and Branson, Steve and Farrell, Ryan and Haber, Scott and Barry, Jessie and Ipeirotis, Panos and Perona, Pietro and Belongie, Serge},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={595--604},
year={2015}
}
``` |