---
license: mit
tags:
  - image-classification
  - birds
  - resnet
  - pytorch
  - wildlife
datasets:
  - nabirds
  - birdsnap
  - inaturalist
pipeline_tag: image-classification
---

# Bird Species Classifier (ResNet-50)

Fine-tuned ResNet-50 models for classifying North American bird species from cropped bird photographs.

## Model Description

These models are ResNet-50 backbones pretrained on ImageNet V2, fine-tuned on the [NABirds](https://dl.allawnmilner.com/nabirds) dataset augmented with [Birdsnap](https://thomasberg.org/) and [iNaturalist](https://www.inaturalist.org/) data. They are designed for use in a photography processing pipeline that first detects birds with YOLO, crops them at full resolution, then classifies the crop.

### Architecture

- **Backbone**: ResNet-50 (ImageNet V2 pretrained)
- **Pooling**: Generalized Mean (GeM) pooling
- **Head**: `Sequential(Dropout(0.4), Linear(2048, num_classes))`
- **Input size**: 240x240 pixels, normalized with ImageNet mean/std
- **Preprocessing**: `ToTensor()` + `Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))`
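The GeM pooling named above can be sketched as a small PyTorch module with a learnable exponent `p` (`p = 1` recovers average pooling; larger `p` approaches max pooling). This is a minimal sketch of the technique, not the exact training implementation; the default `p = 3.0` is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeM(nn.Module):
    """Generalized Mean (GeM) pooling with a learnable exponent p."""

    def __init__(self, p: float = 3.0, eps: float = 1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp keeps the fractional power differentiable at zero activations,
        # then pool the p-th powers over the full spatial extent.
        pooled = F.avg_pool2d(
            x.clamp(min=self.eps).pow(self.p), kernel_size=x.shape[-2:]
        )
        return pooled.pow(1.0 / self.p)
```

Applied to a ResNet-50 feature map of shape `(N, 2048, H, W)`, this produces a `(N, 2048, 1, 1)` tensor, drop-in compatible with the classifier head above.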

### Training Strategy

Three-stage progressive unfreezing:

| Stage | Unfrozen Layers | Purpose |
|-------|-----------------|---------|
| 1 | FC head only | Learn species mapping on frozen backbone features |
| 2 | `layer4` + FC | Adapt high-level features |
| 3 | `layer3` + `layer4` + FC | Fine-tune mid-level features |

Training used an automated research loop (Codex-driven), with a 2-hour time budget per experiment for the 98-species model and 4-10 hour budgets for the 404-species model.

## Available Checkpoints

### `subset98_combined/best.pt` — 98 Target Species

| Metric | Value |
|--------|-------|
| Top-1 Test Accuracy | **97.4%** |
| Top-1 Val Accuracy | 97.6% |
| Classes | 98 target species |
| Training Data | NABirds + Birdsnap + iNaturalist (~38K training images) |
| Total Epochs | 12 |
| Training Time | 2 hours |
| Peak Memory | 589 MB |
| File Size | ~91 MB |

Best run: `20260319_074647_c9dbe6` — stage3 cap=6 + layer2 lr=1.5e-5

### `base_combined/best.pt` — 404 Base Species

| Metric | Value |
|--------|-------|
| Top-1 Test Accuracy | **93.6%** |
| Top-1 Val Accuracy | 93.6% |
| Classes | 404 NABirds base species (sex/morph variants collapsed) |
| Training Data | NABirds + Birdsnap + iNaturalist (~166K training images) |
| Total Epochs | 20 |
| Training Time | ~9.6 hours |
| Peak Memory | 898 MB |
| Batch Size | 128 |
| File Size | ~98 MB |

Best run: `20260319_234135_b8fe6e` — bs=128 + stage lrs 3e-4/6e-5

## Usage

### With the Bird Photography Pipeline

```bash
git clone --branch MVP https://github.com/rkutyna/BirdBrained
cd BirdBrained
pip install -r requirements.txt
python download_models.py
streamlit run frontend/bird_gallery_frontend.py
```

### Standalone Inference (PyTorch)

```python
import torch
from torchvision import models, transforms
from PIL import Image

# Load checkpoint
state_dict = torch.load("subset98_combined/best.pt", map_location="cpu")

# Build model
model = models.resnet50()
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=0.4),
    torch.nn.Linear(model.fc.in_features, 98),  # or 404 for base_combined
)
# Note: the checkpoints use GeM pooling (see Architecture). If loading reports
# unexpected pooling keys, swap model.avgpool for a GeM module or pass
# strict=False to load_state_dict.
model.load_state_dict(state_dict)
model.eval()

# Preprocess a cropped bird image
transform = transforms.Compose([
    transforms.Resize((240, 240)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

img = Image.open("bird_crop.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0)

with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.softmax(logits, dim=1)
    top5_probs, top5_indices = probs.topk(5)
```

Label names are provided in the repository as CSV files:
- `label_names.csv` — 98 target species
- `label_names_nabirds_base_species.csv` — 404 base species

## Training Data

| Dataset | Images | Species | Role |
|---------|--------|---------|------|
| [NABirds](https://dl.allawnmilner.com/nabirds) | ~48K | 555 specific / 404 base | Train + Val + Test |
| [Birdsnap](https://thomasberg.org/) | ~50K | ~335 matched | Train only |
| [iNaturalist](https://www.inaturalist.org/) | ~70K | up to 280/species | Train only |

Validation and test splits use NABirds data only (no external data leakage).

## Limitations

- Trained on North American bird species only (NABirds taxonomy).
- Expects **cropped bird images** as input — not full scene photos. Use a bird detector (e.g., YOLO) to crop first.
- The 98-species model covers only a curated subset; out-of-distribution species will be misclassified into the nearest known class.
- Performance may degrade on heavily backlit, motion-blurred, or partially occluded subjects.

## Citation

If you use these models, please cite the NABirds dataset:

```bibtex
@inproceedings{van2015building,
  title={Building a bird recognition app and large scale dataset with citizen scientists: The fine print in fine-grained dataset collection},
  author={Van Horn, Grant and Branson, Steve and Farrell, Ryan and Haber, Scott and Barry, Jessie and Ipeirotis, Panos and Perona, Pietro and Belongie, Serge},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={595--604},
  year={2015}
}
```