# Bowerbird Individual Classifier (ResNet50)

This repository contains the weights for a ResNet50-based individual ID classifier
trained on 16 Spotted Bowerbird individuals, as part of a project of the Fusani Lab
(see https://github.com/sarequi/Bowerbird-ID).

- Base model: `torchvision.models.resnet50` with `ResNet50_Weights.DEFAULT` (ImageNet)
- Input size: 512 × 512 RGB
- Normalization:
  - mean = [0.485, 0.456, 0.406]
  - std = [0.229, 0.224, 0.225]
- Checkpoint file: `Bbird_individual_classifier.pth`

> This model is **not** generic. It is specific to the 16 individuals it was trained on.
> `NUM_CLASSES` must match the number of bird IDs used during training, unless the model is re-trained.
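
One way to confirm the class count before rebuilding the model is to read it from the checkpoint's final layer. The sketch below assumes the checkpoint is a plain state dict using torchvision's ResNet key names (so the classifier head is stored under `fc.weight`); `head_size` and `fake_state` are hypothetical names used only for illustration.

```python
import torch


def head_size(state_dict):
    """Number of output classes implied by the final layer's weight matrix."""
    # torchvision's ResNet stores the classifier head as `fc`;
    # its weight has shape (num_classes, in_features).
    return state_dict["fc.weight"].shape[0]


# Stand-in state dict with the same head shape as a 16-class ResNet50,
# so the sketch runs without downloading the real checkpoint.
fake_state = {"fc.weight": torch.zeros(16, 2048), "fc.bias": torch.zeros(16)}
print(head_size(fake_state))  # 16
```

With the real checkpoint, passing the result of `torch.load(...)` to `head_size` should give the value to use for `NUM_CLASSES`, provided the key-name assumption holds.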

## Usage

```python
import torch
from torchvision.models import resnet50, ResNet50_Weights
from huggingface_hub import hf_hub_download

repo_id = "sarequi/bowerbird-individual-classifier"

# Download checkpoint
ckpt_path = hf_hub_download(
    repo_id=repo_id,
    filename="Bbird_individual_classifier.pth",
)

NUM_CLASSES = 16  # number of individuals used during training

# Rebuild model architecture
model = resnet50(weights=ResNet50_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, NUM_CLASSES)

# Load weights
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```