---
language: en
license: mit
tags:
- image-captioning
- pytorch
- resnet
- attention
- gru
- glove
- flickr8k
- flickr30k
- show-attend-and-tell
datasets:
- nlphuji/flickr8k
metrics:
- bleu
- meteor
- cider
- rouge
library_name: pytorch
pipeline_tag: image-to-text
---
# Flickr Image Captioning: ResNet50 + Bahdanau Attention + GRU + GloVe
This model generates a natural-language description of an image. It uses a
**ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
module, and a **GRU decoder** initialized with **GloVe 6B 300d** embeddings,
trained on the merged **Flickr8k + Flickr30k** dataset (39,874 images × 5
captions). It follows the architecture of
[*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044)
and adds label smoothing, scheduled sampling, and two-phase CNN fine-tuning.
## Test-set performance (beam search, k = 5)
Evaluated on the held-out 1,873-image test split. The split is made at the
image level, so all captions of an image stay in the same partition and none
cross train/val/test.
| Metric | Value |
|---|---|
| BLEU-1 | 0.6859 |
| BLEU-2 | 0.5289 |
| BLEU-3 | 0.4041 |
| **BLEU-4** | **0.3093** |
| METEOR | 0.4709 |
| CIDEr | 0.7961 |
| ROUGE-L | 0.5257 |
Beam search uses length-normalized log-probs (`alpha = 0.7`) and a
repetition penalty of `1.2`.
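
As a rough illustration of the two decoding settings above, the sketch below scores a finished beam by dividing its summed log-probability by `length ** alpha`, and makes already-generated tokens less likely by multiplying their log-probabilities by the penalty. Both formulas are assumptions based on common practice; this README does not state the exact form the repository uses, and `normalized_score` and `penalize_repeats` are hypothetical names.

```python
def normalized_score(sum_logprob, length, alpha=0.7):
    """Length-normalized beam score (assumed form: sum of log-probs / length**alpha)."""
    return sum_logprob / (length ** alpha)


def penalize_repeats(logprobs, generated_ids, penalty=1.2):
    """Return a copy of `logprobs` with already-generated tokens made less likely.

    Assumed form: log-probs are <= 0, so multiplying by a penalty > 1
    pushes them further down.
    """
    out = list(logprobs)
    for tok in set(generated_ids):
        out[tok] *= penalty
    return out


# A short caption has the better raw sum, but the longer one wins after
# length normalization.
short = normalized_score(-3.0, length=3)
long_ = normalized_score(-5.0, length=8)
print(short < long_)
```

With `alpha = 0` the score reduces to the raw sum, which favors short captions; larger values shift the preference toward longer ones.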
## Architecture
```
Image (3, 224, 224)
  └─ ResNet50 (pretrained, frozen first 10 epochs, last 2 blocks fine-tuned)
       output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
  └─ Bahdanau attention  V·tanh(W_enc(features) + W_dec(h_prev))
       output: context vector (B, 2048), attention weights (B, 49)
  └─ GRUCell (per timestep, re-queries attention each step)
       hidden state size: 1024, embedding size: 300 (GloVe 6B 300d)
  └─ Linear → vocab logits (V = 10,111)
```
Total parameters: **~37 M** (about 25 M in the ResNet encoder, which is frozen during Phase A, and about 12 M in the trainable decoder/projection).
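
The attention step in the diagram can be sketched in plain Python for a single image. This is a minimal, framework-free illustration of additive attention, not the repository's module: the function name, the attention dimension `A`, and the absence of bias terms are assumptions made for the example. In the model, `features` would have 49 rows of 2048 values and `h_prev` would have 1024 values.

```python
import math


def additive_attention(features, h_prev, W_enc, W_dec, v):
    """Additive (Bahdanau) attention for one image.

    features: L x D list of spatial feature vectors
    h_prev:   previous decoder hidden state, length H
    W_enc:    A x D projection, W_dec: A x H projection, v: length-A vector
    Returns (context of length D, attention weights of length L).
    """
    def matvec(matrix, vec):
        return [sum(m * x for m, x in zip(row, vec)) for row in matrix]

    dec_proj = matvec(W_dec, h_prev)  # shared by every spatial location
    scores = []
    for feat in features:
        enc_proj = matvec(W_enc, feat)
        hidden = [math.tanh(e + d) for e, d in zip(enc_proj, dec_proj)]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))

    # Softmax over locations; subtracting the max keeps exp() stable.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Context vector: attention-weighted sum of the feature vectors.
    dim = len(features[0])
    context = [sum(w * f[j] for w, f in zip(weights, features)) for j in range(dim)]
    return context, weights


# Toy example: 2 locations, D = 2, H = 1, A = 1.
ctx, attn = additive_attention(
    features=[[1.0, 0.0], [0.0, 1.0]],
    h_prev=[0.0],
    W_enc=[[1.0, 0.0]],
    W_dec=[[0.0]],
    v=[1.0],
)
print(attn, ctx)
```

Because the decoder hidden state enters the score, the weights are recomputed at every timestep, which is what lets the decoder look at different image regions for different words.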
## Training details
- **Dataset**: Flickr8k + Flickr30k merged (37,000 train / 1,000 val / 1,873 test)
- **Vocabulary**: 10,111 tokens (frequency threshold 3), built from train
captions only. Special tokens: `<pad>=0`, `<start>=1`, `<end>=2`, `<unk>=3`.
- **Loss**: `CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)` plus
doubly-stochastic regularization `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0`
- **Optimizer**: Adam, decoder LR `3.2e-3`, encoder LR `8e-5` (Phase B)
- **Schedule**: `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`, `patience=3`
- **Two-phase training**: Phase A (epochs 1–10) freezes the CNN; Phase B (epochs 11–35) unfreezes the last 2 ResNet blocks.
- **Scheduled sampling**: sampling probability ramps linearly from 0 to a maximum of 0.25 over the training epochs
- **Batch size**: 256, gradient clip 5.0, seed 42
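
Two of the settings above are compact enough to sketch in plain Python. `doubly_stochastic_penalty` follows the formula in the loss item, assuming the sum runs over timesteps and the mean over images and spatial locations. `sampling_prob` assumes the ramp starts at 0 in epoch 1 and reaches its maximum in the final epoch (35 here). Both function names and those details are assumptions, not code from the repository.

```python
def doubly_stochastic_penalty(alphas, alpha_c=1.0):
    """alpha_c * mean over images and locations of (1 - sum_t alpha[t][i])**2.

    alphas[b][t][i] is the attention weight that image b puts on
    location i at decoding step t.
    """
    squared = []
    for per_image in alphas:
        num_locations = len(per_image[0])
        for i in range(num_locations):
            total = sum(step[i] for step in per_image)
            squared.append((1.0 - total) ** 2)
    return alpha_c * sum(squared) / len(squared)


def sampling_prob(epoch, num_epochs=35, max_prob=0.25):
    """Linear scheduled-sampling ramp: 0 at epoch 1, max_prob at the last epoch."""
    if num_epochs <= 1:
        return max_prob
    return max_prob * (epoch - 1) / (num_epochs - 1)


# One image, one decoding step, two locations: all attention on location 0.
print(doubly_stochastic_penalty([[[1.0, 0.0]]]))
print([sampling_prob(e) for e in (1, 18, 35)])
```

The penalty is zero when every location receives a total attention of 1 across the caption, which encourages the decoder to look at all parts of the image over the course of decoding.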
## Files in this repo
- `attention_gru_glove.pth`: PyTorch checkpoint (encoder + decoder state dicts, config)
- `vocab.pkl`: pickled `Vocabulary` object built from the train split
- `config.json`: JSON copy of the training hyperparameters
- `metrics_beam5.json`: full test-set metrics (beam search k=5)
## Usage
```bash
git clone https://github.com/OmarGamal488/flickr-image-captioning.git
cd flickr-image-captioning
uv sync
```
Then in Python:
```python
import pickle

import torch
from huggingface_hub import hf_hub_download

from src.inference import load_attention_model, caption_image
from src.utils import get_device

# Download the checkpoint and vocabulary from the Hub (cached after the first call).
repo_id = "OmarGamal48812/flickr-captioning"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_gru_glove.pth")
vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

device = get_device()
# vocab.pkl is a pickled Vocabulary object, so the cloned repo's code must be importable.
with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)
caption, beams = caption_image(
    encoder, decoder, "your_image.jpg", vocab, device,
    method="beam", beam_width=5,
)
print(caption)
# Show the three highest-ranked beams with their scores.
for b in beams[:3]:
    print(f"  {b.score:+.3f}  {b.caption}")
```
## Limitations
- **Domain.** Trained on Flickr8k + Flickr30k photos (mostly people, dogs,
outdoor scenes). Performance degrades on cartoons, screenshots, and abstract imagery.
- **Safe-word bias.** Only 8.8% of the 10,111-word vocabulary is used at inference:
the decoder converges on template phrases like *"a man in a white shirt is standing"*.
- **No object counting.** The attention context vector does not preserve object count,
so the model often says "a dog" when the image shows two dogs.
- **Hallucinations.** The decoder can insert objects not in the image when visual
evidence is weak and the language-model prior takes over.
- **English only.** Vocabulary and grammar are entirely from English Flickr captions.
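
The vocabulary-usage figure above can be approximated with a few lines of Python. How the 8.8% was actually measured is not stated here, so this sketch assumes whitespace tokenization of lowercased generated captions and excludes the special tokens; `vocab_usage` is a hypothetical helper.

```python
def vocab_usage(captions, vocab_size,
                specials=("<pad>", "<start>", "<end>", "<unk>")):
    """Fraction of the vocabulary that appears in a set of generated captions."""
    used = {tok for cap in captions for tok in cap.lower().split()}
    used -= set(specials)
    return len(used) / vocab_size


generated = ["a dog runs", "A dog sits"]
# Four distinct words (a, dog, runs, sits) out of a 10-word vocabulary.
print(vocab_usage(generated, vocab_size=10))
```

Running this over the captions generated for the full test split, with `vocab_size=10111`, would give a number comparable to the one quoted, subject to the tokenization assumption.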
## Citation
If you use this checkpoint, please cite the three papers this work builds on:
```bibtex
@inproceedings{xu2015show,
title = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
author = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and
Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
booktitle = {ICML},
year = {2015}
}
@article{bahdanau2014neural,
title = {Neural Machine Translation by Jointly Learning to Align and Translate},
author = {Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
journal = {arXiv preprint arXiv:1409.0473},
year = {2014}
}
@inproceedings{selvaraju2017gradcam,
title = {Grad-{CAM}: Visual Explanations from Deep Networks via Gradient-based Localization},
author = {Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and
Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
booktitle = {ICCV},
year = {2017}
}
```