British Birds Vision Transformer (ViT)
This repository contains a fine-tuned Vision Transformer (ViT) model designed to classify images of 224 common British bird species.
The model was fine-tuned from the vit_base_patch16_224_imagenet21k backbone using Keras 3.
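The backbone name encodes the ViT-Base configuration with 16×16 pixel patches on 224×224 inputs. The sketch below is only an illustration of that idea, not the backbone's own code. It cuts an image into the 14 × 14 = 196 non-overlapping patches that a ViT embeds, each flattened to 16 × 16 × 3 = 768 values. Real implementations typically use a strided convolution, which is equivalent to a linear projection of these flattened patches. The function name `patchify` is hypothetical.

```python
import numpy as np

def patchify(image, patch_size=16):
    """Split an (H, W, C) image into flattened, non-overlapping square patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    ordered row by row (left to right, then top to bottom).
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, C) -> (gh, gw, p, p, C) so each patch is contiguous
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)
```

For a 224×224 RGB input, `patchify(img).shape` is `(196, 768)`.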
Dataset
The dataset comprises 224 classes with exactly 500 image samples per class (112,000 images total).
Images were sourced from eBird on February 28, 2026, using the Birdhouse CLI tool. They were filtered to the Great Britain region (--region=GB), to images with an average rating of at least 4 (--min-avg-rating=4), and to images with at least 2 reviews (--min-reviews=2).
The data was split as follows:
- Training: 80% (89,600 images)
- Validation: 10% (11,200 images)
- Test: 10% (11,200 images)
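The card does not say whether the split was stratified by class. The sketch below assumes a per-class 80/10/10 split (400/50/50 images per class), which reproduces the stated totals of 89,600 / 11,200 / 11,200. The function names, the seed, and the `paths_by_class` structure are hypothetical.

```python
import random

def split_items(items, rng, train_frac=0.8, val_frac=0.1):
    """Shuffle one class's items and cut them into train/val/test lists."""
    items = list(items)
    rng.shuffle(items)
    n_train = round(len(items) * train_frac)
    n_val = round(len(items) * val_frac)
    return items[:n_train], items[n_train:n_train + n_val], items[n_train + n_val:]

def stratified_split(paths_by_class, seed=0):
    """Split every class 80/10/10 and return (path, label) pairs for each subset."""
    rng = random.Random(seed)  # seeded so the split is reproducible
    train, val, test = [], [], []
    for label, paths in sorted(paths_by_class.items()):
        tr, va, te = split_items(paths, rng)
        train += [(p, label) for p in tr]
        val += [(p, label) for p in va]
        test += [(p, label) for p in te]
    return train, val, test
```

With 224 classes of 500 paths each, this yields 89,600 training, 11,200 validation, and 11,200 test pairs, with no image shared between subsets.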
Model Performance
Evaluated on the held-out test set (10%, 11,200 images):
- Test Accuracy (top-1): 85.76%
- Top-3 Accuracy: 94.03%
- Test Loss (sparse categorical cross-entropy): 0.5567
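The card does not say how top-3 accuracy was computed. Under the usual definition, a test image counts as correct if its true class is among the model's three highest-scoring classes. The following is a minimal NumPy sketch, assuming `probs` is an (N, 224) array of model outputs and `labels` holds integer class indices. The helper name is hypothetical; Keras also provides `keras.metrics.SparseTopKCategoricalAccuracy` for the same purpose.

```python
import numpy as np

def top_k_accuracy(probs, labels, k=3):
    """Fraction of rows whose true label is among the k highest scores."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    top_k = np.argsort(probs, axis=1)[:, -k:]           # indices of the k largest scores per row
    hits = (top_k == labels[:, None]).any(axis=1)       # True where the label is in the top k
    return float(hits.mean())
```

With `k=1` this reduces to ordinary accuracy, which is why top-3 accuracy is always at least as high as top-1.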
Files included
- model/final_224c_500i_GB_bird_vit.keras: The full Keras 3 model.
- model/224c_500i_GB_bird_vit_float32.tflite: Float32 TFLite model for mobile/edge deployment.
- model/224c_500i_GB_bird_vit_quantized.tflite: INT8-quantized TFLite model for optimized edge deployment.
- model/224c_500i_GB_bird_classes.json: Ordered list of the 224 eBird taxon codes representing the classes.
- model/friendly_class_names.csv: Mapping of taxon codes to human-readable bird names.
- training/british-birds-vit-training.ipynb: The Jupyter Notebook used to train the model.
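The card does not show how to run the TFLite files. The following is a minimal sketch, assuming TensorFlow is installed and that both TFLite models take the same 1×224×224×3 input of raw 0-255 pixel values as the Keras usage snippet (the card does not confirm this). It uses `tf.lite.Interpreter`. If the INT8 model exposes integer input or output tensors, the sketch applies TFLite's affine quantization, q = round(x / scale) + zero_point, using the parameters the interpreter reports. The helper names `quantize_input` and `run_tflite` are hypothetical.

```python
import numpy as np

def quantize_input(x, scale, zero_point, dtype):
    """Map float values to an integer tensor: q = round(x / scale) + zero_point."""
    if scale == 0:  # a scale of 0 means the tensor carries no quantization parameters
        return x.astype(dtype)
    info = np.iinfo(dtype)
    q = np.round(x / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)

def run_tflite(model_path, img_array):
    """Run one batch through a TFLite model and return float scores."""
    import tensorflow as tf  # imported lazily so the helper above works without TF

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    x = np.asarray(img_array, dtype=np.float32)
    if np.issubdtype(inp["dtype"], np.integer):
        scale, zero_point = inp["quantization"]
        x = quantize_input(x, scale, zero_point, inp["dtype"])
    interpreter.set_tensor(inp["index"], x)
    interpreter.invoke()

    y = interpreter.get_tensor(out["index"])
    if np.issubdtype(out["dtype"], np.integer):
        scale, zero_point = out["quantization"]
        y = (y.astype(np.float32) - zero_point) * scale  # dequantize back to floats
    return y
```

If the quantized model turns out to keep float32 input and output tensors, the integer branches are skipped and the call behaves like the float32 model.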
Usage
The following snippet demonstrates how to download the model, preprocess an image, and use the included CSV file to output a human-readable bird name (e.g., "Arctic Skua - Stercorarius parasiticus", or just the short English name if you use the rspb_name column) rather than the raw eBird taxon code (e.g., "parjae").
```python
import csv
import json

import keras
import numpy as np
from huggingface_hub import hf_hub_download

REPO_ID = "rossheaton/british-birds-vit-base-patch16-224"

# Download the model, the class codes, and the friendly-names CSV
model_path = hf_hub_download(repo_id=REPO_ID, filename="model/final_224c_500i_GB_bird_vit.keras")
labels_path = hf_hub_download(repo_id=REPO_ID, filename="model/224c_500i_GB_bird_classes.json")
csv_path = hf_hub_download(repo_id=REPO_ID, filename="model/friendly_class_names.csv")

# Load the Keras model
model = keras.models.load_model(model_path)

# Load the taxon codes (an ordered list: index i is the model's output class i)
with open(labels_path, 'r', encoding='utf-8') as f:
    taxon_codes = json.load(f)

# Build a dictionary mapping taxon codes to friendly names
taxon_to_friendly = {}
with open(csv_path, mode='r', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        # Maps e.g. "parjae" -> "Arctic Skua - Stercorarius parasiticus"
        # Note: change 'ebird_search_term' to 'rspb_name' if you prefer just the short English name
        taxon_to_friendly[row['ebird_taxon_code']] = row['ebird_search_term']

# Load and preprocess a local image (resized to the model's 224x224 input)
image_path = "path/to/bird/image.jpg"
img = keras.utils.load_img(image_path, target_size=(224, 224))
img_array = keras.utils.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)  # add a batch dimension: (1, 224, 224, 3)

# Predict (one image, so take the first row of the batch output)
predictions = model.predict(img_array)
predicted_index = int(np.argmax(predictions[0]))
confidence = float(predictions[0][predicted_index]) * 100

# Get the raw taxon code, then look up the friendly name
predicted_taxon = taxon_codes[predicted_index]
friendly_name = taxon_to_friendly.get(predicted_taxon, predicted_taxon)  # fall back to the taxon code if missing

print(f"Prediction: {friendly_name} ({confidence:.2f}% confidence)")
```
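Because the card reports a top-3 accuracy of 94.03%, it can be useful to show the three most likely species rather than only the best guess. The following is a minimal sketch, assuming (as the snippet above does) that the model's outputs are softmax probabilities. The helper name `top_k_predictions` is hypothetical.

```python
import numpy as np

def top_k_predictions(probs, taxon_codes, taxon_to_friendly, k=3):
    """Return [(name, percent), ...] for the k highest-scoring classes, best first."""
    probs = np.asarray(probs).ravel()
    order = np.argsort(probs)[::-1][:k]  # class indices sorted by descending score
    results = []
    for i in order:
        code = taxon_codes[i]
        name = taxon_to_friendly.get(code, code)  # fall back to the raw taxon code
        results.append((name, float(probs[i]) * 100))
    return results
```

Called as `top_k_predictions(predictions[0], taxon_codes, taxon_to_friendly)` after the usage snippet, it returns the three most likely birds with their confidence percentages.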
Training Details
- Optimiser: AdamW (Learning Rate: 1e-4, Weight Decay: 0.01)
- Loss: Sparse Categorical Crossentropy
- Epochs: 8 (the best weights, selected by validation accuracy, came from epoch 7 and were saved)
- Batch Size: 64
- Hardware: 1× NVIDIA RTX 4090 (24 GB VRAM), 16 vCPUs, 62 GB RAM
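To show what the AdamW settings above control, here is a minimal NumPy sketch of a single AdamW update on one weight using those hyperparameters. AdamW differs from Adam with L2 regularization in one respect: the weights are shrunk directly by `learning_rate * weight_decay` instead of having a decay term added to the gradient. The sketch follows the textbook formulation. Keras's internal arithmetic differs in minor details, such as where epsilon enters. The betas and epsilon are Keras's `AdamW` defaults, not values stated in the card, and the function name is hypothetical.

```python
import numpy as np

def adamw_step(w, g, m, v, t, lr=1e-4, weight_decay=0.01,
               beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """One AdamW update for weight(s) w with gradient g; returns (w, m, v).

    t is the 1-based step count used for bias correction.
    """
    # Decoupled weight decay: shrink the weights directly, scaled by the learning rate
    w = w - lr * weight_decay * w
    # Adam moment estimates (running mean and uncentered variance of the gradient)
    m = beta_1 * m + (1 - beta_1) * g
    v = beta_2 * v + (1 - beta_2) * g * g
    # Bias correction for the zero-initialized moments
    m_hat = m / (1 - beta_1 ** t)
    v_hat = v / (1 - beta_2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + epsilon)
    return w, m, v
```

On the first step with a unit gradient, the Adam part moves a weight by about the learning rate (1e-4), while weight decay removes a further 1e-6 of its value. Even with a zero gradient, the weights still shrink slightly each step because of the decoupled decay.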