# Bowerbird Individual Classifier (ResNet50)
This repository contains the weights of a ResNet50-based classifier that identifies individual Spotted Bowerbirds.
It was trained on images of a fixed set of individuals as part of a Fusani Lab project
(see https://github.com/sarequi/Bowerbird-ID).
- Base model: `torchvision.models.resnet50` initialized with `ResNet50_Weights.DEFAULT` (ImageNet-pretrained)
- Input size: 512 × 512 RGB
- Normalization:
  - mean = [0.485, 0.456, 0.406]
  - std = [0.229, 0.224, 0.225]
- Checkpoint file: `Bbird_individual_classifier.pth`
> This model is **not** generic. It can only assign images to the 16 individuals it was trained on.
> `NUM_CLASSES` must equal the number of bird IDs used during training (16); use a different value only if the model is re-trained.
## Usage
```python
import torch
from torchvision.models import resnet50
from huggingface_hub import hf_hub_download

repo_id = "sarequi/bowerbird-individual-classifier"

# Download the checkpoint from the Hugging Face Hub
ckpt_path = hf_hub_download(
    repo_id=repo_id,
    filename="Bbird_individual_classifier.pth",
)

NUM_CLASSES = 16  # number of individuals used during training

# Rebuild the model architecture. ImageNet weights are not needed here
# because load_state_dict() below overwrites every parameter.
model = resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, NUM_CLASSES)

# Load the fine-tuned weights (the checkpoint is a plain state_dict)
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch BatchNorm layers to inference mode
```