---
language: en
license: mit
tags:
- vision
- image-captioning
- image-classification
- bird-species
datasets:
- cub-200-2011
---
# Bird Captioning and Classification Model (CUB-200-2011)
This is a fine-tuned VisionEncoderDecoderModel based on `nlpconnect/vit-gpt2-image-captioning`, trained on the CUB-200-2011 dataset for bird species classification and image captioning.
## Model Description
- **Base Model**: ViT-GPT2 (`nlpconnect/vit-gpt2-image-captioning`)
- **Tasks**:
  - Generates descriptive captions for bird images, including species and attributes.
  - Classifies images into one of 200 bird species.
- **Dataset**: CUB-200-2011 (11,788 images, 200 bird species)
- **Training**: 10 epochs, batch size 16, mixed precision, AdamW optimizer (lr=3e-5), combined loss (caption + 0.5 * classification).
- **Best Validation Loss**: 0.0690 (Epoch 3)
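The combined loss listed above (caption + 0.5 * classification) can be sketched as follows. This is a minimal illustration, not the training code shipped with this model: the function name `combined_loss`, the tensor shapes, and the use of `-100` as the ignored label index are assumptions, and the caption logits are assumed to be already aligned with their labels.

```python
import torch
import torch.nn.functional as F

CLS_WEIGHT = 0.5  # weight on the classification term, as stated in this card


def combined_loss(caption_logits, caption_labels, class_logits, class_labels):
    """Return caption loss + CLS_WEIGHT * classification loss.

    caption_logits: (batch, seq_len, vocab_size)
    caption_labels: (batch, seq_len); -100 marks ignored positions (assumption)
    class_logits:   (batch, num_classes)
    class_labels:   (batch,)
    """
    # Token-level cross-entropy over the flattened caption sequence.
    caption_loss = F.cross_entropy(
        caption_logits.reshape(-1, caption_logits.size(-1)),
        caption_labels.reshape(-1),
        ignore_index=-100,
    )
    # Image-level cross-entropy over the 200 species classes.
    class_loss = F.cross_entropy(class_logits, class_labels)
    return caption_loss + CLS_WEIGHT * class_loss
```

The 0.5 weight keeps the captioning objective dominant while still training the classification head.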
## Files
- `model.safetensors`: Trained model weights
- `config.json`: Model configuration
- `preprocessor_config.json`: ViTImageProcessor settings
- `tokenizer_config.json`, `vocab.json`: GPT2 tokenizer files
- `species_mapping.txt`: Mapping of class indices to bird species names
- `cub200_captions.csv`: Generated captions for the dataset
- `model.py`: Custom `BirdCaptioningModel` class definition
## Usage
### Prerequisites
```bash
pip install transformers torch huggingface_hub pillow
```
### Load Model and Dependencies
```python
from transformers import ViTImageProcessor, AutoTokenizer
import torch
from model import BirdCaptioningModel  # requires model.py from this repository in the working directory
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200").to(device)
image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
model.eval()
# Load species mapping
species_mapping = {}
with open("species_mapping.txt", "r") as f:
    for line in f:
        idx, name = line.strip().split(",", 1)
        species_mapping[int(idx)] = name
```
### Inference
```python
from PIL import Image
def predict_bird_image(image_path):
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
        _, class_logits = model(pixel_values)
    predicted_class_idx = torch.argmax(class_logits, dim=1).item()
    confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    species = species_mapping.get(predicted_class_idx, "Unknown")
    return caption, species, confidence
# Example
caption, species, confidence = predict_bird_image("/kaggle/input/cub2002011/CUB_200_2011/images/006.Least_Auklet/Least_Auklet_0007_795123.jpg")
print(f"Caption: {caption}")
print(f"Species: {species}")
print(f"Confidence: {confidence:.2f}%")
```
## Dataset
- **CUB-200-2011**: 11,788 images of 200 bird species with attribute annotations.
- Captions were generated based on species names and attributes (e.g., bill shape, wing color).
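To illustrate how a caption might be composed from a species name and attributes, here is a rough sketch. The function `build_caption` and its sentence template are hypothetical; only the inputs (species name, attributes such as bill shape and wing color) and the cap of 5 attributes come from this card. The actual generation script may have used different wording.

```python
def build_caption(species, attributes, max_attributes=5):
    """Compose a caption from a species name and (attribute, value) pairs.

    The sentence template is hypothetical; the 5-attribute cap follows
    the limit stated in this model card.
    """
    name = species.replace("_", " ")
    # Keep at most `max_attributes` pairs, phrased as "<value> <attribute>".
    parts = [f"{value} {attr}" for attr, value in attributes[:max_attributes]]
    if not parts:
        return f"A {name}."
    if len(parts) == 1:
        described = parts[0]
    else:
        described = ", ".join(parts[:-1]) + " and " + parts[-1]
    return f"A {name} with {described}."


print(build_caption("Least_Auklet", [("bill shape", "cone"), ("wing color", "black")]))
```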
## Training Details
- **Loss**: Combined captioning (CrossEntropy) and classification (CrossEntropy) loss.
- **Optimizer**: AdamW (lr=3e-5)
- **Scheduler**: CosineAnnealingLR
- **Hardware**: GPU (CUDA)
- **Training Time**: ~5 min/epoch
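The optimizer and scheduler listed above could be set up roughly as shown below. This is a sketch under assumptions: the `nn.Linear` layer is a stand-in for `BirdCaptioningModel`, and stepping the scheduler once per epoch with `T_max` equal to the 10 training epochs is a guess, since the card does not state the schedule length.

```python
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

EPOCHS = 10  # from this model card
LR = 3e-5    # from this model card

model = nn.Linear(8, 200)  # stand-in; replace with BirdCaptioningModel
optimizer = AdamW(model.parameters(), lr=LR)
# Assumption: one scheduler step per epoch, annealing over all epochs.
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

lrs = []
for epoch in range(EPOCHS):
    # ... forward/backward passes over the training batches would go here ...
    optimizer.step()   # placeholder for the per-batch optimizer steps
    scheduler.step()   # decay the learning rate along a cosine curve
    lrs.append(scheduler.get_last_lr()[0])
```

With these settings the learning rate falls from 3e-5 toward zero over the 10 epochs.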
## Limitations
- Validation loss is lowest at epoch 3 and rises afterwards, which suggests overfitting in later epochs.
- Captions are limited to species and up to 5 attributes.
- Classification accuracy is not reported.
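Since accuracy is not reported, users may want to measure it themselves on a held-out split. The helper below is a minimal sketch, not part of this repository: `top1_accuracy` is a hypothetical name, and it assumes classification logits of shape `(batch, num_classes)` like those returned as the second output of the model's forward pass.

```python
import torch


def top1_accuracy(class_logits, labels):
    """Return the fraction of rows whose highest-scoring class equals the label."""
    preds = class_logits.argmax(dim=1)
    return (preds == labels).float().mean().item()


# Toy example with 3 images and 2 classes.
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 0, 0])
acc = top1_accuracy(logits, labels)
```

For a full evaluation, the per-batch correct counts would be accumulated over the whole split before dividing by the number of images.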
## License
MIT License
## Contact
For issues, contact INVERTO on Hugging Face.