---
language: en
license: mit
tags:
- vision
- image-captioning
- image-classification
- bird-species
datasets:
- cub-200-2011
---

# Bird Captioning and Classification Model (CUB-200-2011)

This is a fine-tuned VisionEncoderDecoderModel based on `nlpconnect/vit-gpt2-image-captioning`, trained on the CUB-200-2011 dataset for bird species classification and image captioning.

## Model Description

- **Base Model**: ViT-GPT2 (`nlpconnect/vit-gpt2-image-captioning`)
- **Tasks**:
  - Generates descriptive captions for bird images, including species and attributes.
  - Classifies images into one of 200 bird species.
- **Dataset**: CUB-200-2011 (11,788 images, 200 bird species)
- **Training**: 10 epochs, batch size 16, mixed precision, AdamW optimizer (lr=3e-5), combined loss (caption + 0.5 * classification).
- **Best Validation Loss**: 0.0690 (epoch 3)

## Files

- `model.safetensors`: Trained model weights
- `config.json`: Model configuration
- `preprocessor_config.json`: ViTImageProcessor settings
- `tokenizer_config.json`, `vocab.json`: GPT2 tokenizer files
- `species_mapping.txt`: Mapping of class indices to bird species names
- `cub200_captions.csv`: Generated captions for the dataset
- `model.py`: Custom `BirdCaptioningModel` class definition

## Usage

### Prerequisites

```bash
pip install transformers torch huggingface_hub
```

### Load Model and Dependencies

```python
from transformers import ViTImageProcessor, AutoTokenizer
from huggingface_hub import PyTorchModelHubMixin
import torch

from model import BirdCaptioningModel  # download model.py from this repo first

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model, image processor, and tokenizer
model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200").to(device)
image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
model.eval()

# Load species mapping (one "index,species name" pair per line)
species_mapping = {}
with open("species_mapping.txt", "r") as f:
    for line in f:
        idx, name = line.strip().split(",", 1)
        species_mapping[int(idx)] = name
```

### Inference

```python
from PIL import Image

def predict_bird_image(image_path):
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        # Generate a caption with beam search
        output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
        # Run the forward pass to get classification logits
        _, class_logits = model(pixel_values)
    predicted_class_idx = torch.argmax(class_logits, dim=1).item()
    confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    species = species_mapping.get(predicted_class_idx, "Unknown")
    return caption, species, confidence

# Example
caption, species, confidence = predict_bird_image("/kaggle/input/cub2002011/CUB_200_2011/images/006.Least_Auklet/Least_Auklet_0007_795123.jpg")
print(f"Caption: {caption}")
print(f"Species: {species}")
print(f"Confidence: {confidence:.2f}%")
```

## Dataset

- **CUB-200-2011**: 11,788 images of 200 bird species with attribute annotations.
- Captions were generated from species names and attributes (e.g., bill shape, wing color).

## Training Details

- **Loss**: Combined captioning (cross-entropy) and classification (cross-entropy) loss.
- **Optimizer**: AdamW (lr=3e-5)
- **Scheduler**: CosineAnnealingLR
- **Hardware**: GPU (CUDA)
- **Training Time**: ~5 min/epoch

## Limitations

- May overfit after epoch 3 (validation loss begins to increase).
- Captions are limited to the species name and up to 5 attributes.
- Classification accuracy was not explicitly reported.

## License

MIT License

## Contact

For issues, contact INVERTO on Hugging Face.
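The loading snippet above assumes `species_mapping.txt` contains one `index,species name` pair per line. The actual file contents are repo-specific; the sample rows below are illustrative only, but they let the parsing logic be sanity-checked in isolation:

```python
import io

# Illustrative rows in the assumed "index,name" format; real entries come
# from species_mapping.txt in the repo. split(",", 1) splits only on the
# first comma, so any commas inside the species name are preserved.
sample = io.StringIO("0,Black_footed_Albatross\n1,Laysan_Albatross\n5,Least_Auklet\n")

species_mapping = {}
for line in sample:
    idx, name = line.strip().split(",", 1)
    species_mapping[int(idx)] = name

print(species_mapping.get(5, "Unknown"))    # a mapped index returns its name
print(species_mapping.get(199, "Unknown"))  # unmapped indices fall back to "Unknown"
```

The `"Unknown"` fallback mirrors the `species_mapping.get(predicted_class_idx, "Unknown")` call in the inference function.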
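The inference function converts class logits into a confidence percentage via softmax over the 200 classes. The same computation in isolation, as a dependency-free sketch (the 4-way dummy logits below stand in for the model's actual 200-way output):

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)  # subtract max before exponentiating to avoid overflow
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Dummy logits; the real model emits one logit per bird species.
class_logits = [0.2, 4.1, -0.5, 1.0]
probs = softmax(class_logits)
predicted_class_idx = max(range(len(probs)), key=probs.__getitem__)  # argmax
confidence = probs[predicted_class_idx] * 100

print(predicted_class_idx, f"{confidence:.2f}%")
```

This mirrors the `torch.argmax` / `torch.nn.functional.softmax` pair in `predict_bird_image`, just without the tensor machinery.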
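The training objective combines the two cross-entropy terms as `caption_loss + 0.5 * classification_loss`. A minimal pure-Python sketch of that weighting on single-example dummy logits (real training applies cross-entropy to token-level caption logits and 200-way class logits):

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

# Dummy logits standing in for the model's caption and class heads.
caption_loss = cross_entropy([2.0, 0.5, -1.0], target=0)
class_loss = cross_entropy([0.1, 3.0, 0.2], target=1)

# Combined objective described in Training Details: caption + 0.5 * classification
total_loss = cross_entropy([2.0, 0.5, -1.0], target=0) + 0.5 * cross_entropy([0.1, 3.0, 0.2], target=1)
print(f"{total_loss:.4f}")
```

The 0.5 weight down-weights the classification term relative to captioning; it is stated in the model description above, not derived here.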