---
language: en
license: mit
tags:
- vision
- image-captioning
- image-classification
- bird-species
datasets:
- cub-200-2011
---

# Bird Captioning and Classification Model (CUB-200-2011)

This is a fine-tuned `VisionEncoderDecoderModel` based on `nlpconnect/vit-gpt2-image-captioning`, trained on the CUB-200-2011 dataset for bird species classification and image captioning.

## Model Description

- **Base Model**: ViT-GPT2 (`nlpconnect/vit-gpt2-image-captioning`)
- **Tasks**:
  - Generates descriptive captions for bird images, including species and attributes.
  - Classifies images into one of 200 bird species.
- **Dataset**: CUB-200-2011 (11,788 images, 200 bird species)
- **Training**: 10 epochs, batch size 16, mixed precision, AdamW optimizer (lr=3e-5), combined loss (caption + 0.5 * classification).
- **Best Validation Loss**: 0.0690 (Epoch 3)

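The combined loss listed above can be read as `caption_loss + 0.5 * classification_loss`. The dependency-free sketch below illustrates that weighting for a single example. The function names and the toy logits are illustrative assumptions; the training code itself is not part of this card.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def combined_loss(caption_loss, class_logits, class_target, class_weight=0.5):
    """Caption loss plus a down-weighted classification loss."""
    return caption_loss + class_weight * cross_entropy(class_logits, class_target)


# Toy values for illustration only (not taken from training).
loss = combined_loss(caption_loss=1.0, class_logits=[2.0, 0.0, 0.0], class_target=0)
print(f"combined loss: {loss:.4f}")
```

A weight below 1 keeps the classification term from dominating the caption objective; 0.5 is the value stated in this card.
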
## Files

- `model.safetensors`: Trained model weights
- `config.json`: Model configuration
- `preprocessor_config.json`: ViTImageProcessor settings
- `tokenizer_config.json`, `vocab.json`: GPT2 tokenizer files
- `species_mapping.txt`: Mapping of class indices to bird species names
- `cub200_captions.csv`: Generated captions for the dataset
- `model.py`: Custom `BirdCaptioningModel` class definition

## Usage

### Prerequisites

```bash
pip install transformers torch huggingface_hub pillow
```

### Load Model and Dependencies

```python
import torch
from transformers import ViTImageProcessor, AutoTokenizer

from model import BirdCaptioningModel  # Save model.py from this repository locally first

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model, image processor, and tokenizer
model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200").to(device)
image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
model.eval()

# Load species mapping (species_mapping.txt must be in the working directory);
# each line has the form "<class index>,<species name>"
species_mapping = {}
with open("species_mapping.txt", "r") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        idx, name = line.split(",", 1)
        species_mapping[int(idx)] = name
```

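The loading code above assumes each line of `species_mapping.txt` has the form `index,name`. The sketch below isolates that parsing step on an in-memory sample so the format is easy to check. The `parse_species_mapping` helper and the sample lines are hypothetical; whether indices in the real file start at 0 or 1 should be checked against the file itself.

```python
def parse_species_mapping(lines):
    """Parse 'index,name' lines into a {class index: species name} dict."""
    mapping = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        # Split on the first comma only, so names containing commas survive.
        idx, name = line.split(",", 1)
        mapping[int(idx)] = name.strip()
    return mapping


# Hypothetical sample lines; the real file may use different indices or names.
sample = ["0,Black_footed_Albatross", "", "1,Laysan_Albatross"]
print(parse_species_mapping(sample))
```
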
### Inference

```python
from PIL import Image


def predict_bird_image(image_path):
    """Return (caption, species, confidence in percent) for one image."""
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        # Caption: beam search on the wrapped encoder-decoder
        output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
        # Classification: the second output of the forward pass holds the class logits
        _, class_logits = model(pixel_values)
    predicted_class_idx = torch.argmax(class_logits, dim=1).item()
    confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    species = species_mapping.get(predicted_class_idx, "Unknown")
    return caption, species, confidence


# Example
caption, species, confidence = predict_bird_image("/kaggle/input/cub2002011/CUB_200_2011/images/006.Least_Auklet/Least_Auklet_0007_795123.jpg")
print(f"Caption: {caption}")
print(f"Species: {species}")
print(f"Confidence: {confidence:.2f}%")
```

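The confidence returned above is the softmax probability of the top class. To inspect several candidate species rather than one, the same logits can be ranked. The dependency-free sketch below shows the idea on a plain list of logits. The `top_k_species` helper and the toy mapping are hypothetical; with tensors, `torch.topk` would do the ranking.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_species(logits, species_mapping, k=3):
    """Return the k most probable (species, confidence in percent) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(species_mapping.get(i, "Unknown"), probs[i] * 100) for i in ranked]


# Toy logits and placeholder names, for illustration only.
toy_mapping = {0: "species_a", 1: "species_b", 2: "species_c"}
for name, pct in top_k_species([2.0, 0.5, 1.0], toy_mapping, k=2):
    print(f"{name}: {pct:.2f}%")
```
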
## Dataset

- **CUB-200-2011**: 11,788 images of 200 bird species with attribute annotations.
- Captions were generated based on species names and attributes (e.g., bill shape, wing color).

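As a rough illustration of how a caption might be assembled from a species name and attribute annotations, consider the sketch below. The template wording, the `build_caption` helper, and the example attributes are assumptions made for illustration; this card does not document the actual generation procedure. The cap of five attributes mirrors the limit noted under Limitations.

```python
def build_caption(species, attributes, max_attributes=5):
    """Compose a caption from a species name and a list of attribute phrases."""
    name = species.replace("_", " ")
    kept = attributes[:max_attributes]  # keep at most max_attributes phrases
    if not kept:
        return f"A photo of a {name}."
    if len(kept) == 1:
        details = kept[0]
    else:
        details = ", ".join(kept[:-1]) + " and " + kept[-1]
    return f"A photo of a {name} with {details}."


# Placeholder attribute phrases, not real CUB annotations.
print(build_caption("Least_Auklet", ["a short bill", "dark wings"]))
```
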
## Training Details

- **Loss**: Combined captioning (CrossEntropy) and classification (CrossEntropy) loss.
- **Optimizer**: AdamW (lr=3e-5)
- **Scheduler**: CosineAnnealingLR
- **Hardware**: GPU (CUDA)
- **Training Time**: ~5 min/epoch

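For intuition on the scheduler, `CosineAnnealingLR` in its closed form decays the learning rate from its initial value toward a minimum along half a cosine wave. The sketch below evaluates that closed form per epoch. Here `t_max=10` (matching the 10 epochs) and `eta_min=0` are assumptions, since this card does not state the scheduler settings or whether it was stepped per epoch or per batch.

```python
import math


def cosine_annealing_lr(step, base_lr=3e-5, t_max=10, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Learning rate at the start of each of the 10 epochs (assumed per-epoch stepping).
for epoch in range(10):
    print(f"epoch {epoch + 1}: lr = {cosine_annealing_lr(epoch):.2e}")
```
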
## Limitations

- The model may overfit after Epoch 3, where validation loss starts to increase.
- Captions are limited to the species name and up to 5 attributes.
- Classification accuracy is not reported.

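Because validation loss rises after the best epoch, a common mitigation is to keep the checkpoint with the lowest validation loss and stop once it has not improved for a few epochs. The sketch below shows that logic on a list of per-epoch losses. The `early_stopping_epoch` helper, the `patience` value, and the loss values are all placeholders and are not results from this model.

```python
def early_stopping_epoch(val_losses, patience=2):
    """Return (best_epoch, stop_epoch), both 1-indexed.

    Training stops once `patience` epochs pass without a new best loss.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses)


# Placeholder losses, for illustration only.
best, stopped = early_stopping_epoch([0.9, 0.6, 0.5, 0.55, 0.6, 0.7], patience=2)
print(f"best epoch: {best}, stopped at epoch: {stopped}")
```
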
## License

MIT License

## Contact

For issues, contact INVERTO on Hugging Face.