---
language: en
license: mit
tags:
- image-captioning
- pytorch
- resnet
- attention
- gru
- glove
- flickr8k
- flickr30k
- show-attend-and-tell
datasets:
- nlphuji/flickr8k
metrics:
- bleu
- meteor
- cider
- rouge
library_name: pytorch
pipeline_tag: image-to-text
---

# Flickr Image Captioning: ResNet50 + Bahdanau Attention + GRU + GloVe

This model generates a natural-language description of an image. It combines a
**ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
module, and a **GRU decoder** whose word embeddings are initialized from **GloVe 6B 300d**.
It was trained on the merged **Flickr8k + Flickr30k** dataset (39,874 images × 5
captions). The design follows
[*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044),
with a GRU in place of the paper's LSTM, and adds label smoothing, scheduled
sampling, and two-phase CNN fine-tuning.

## Test-set performance (beam search, k = 5)

Evaluated on the held-out 1,873-image test split. The split is made at the image
level: all five captions of an image stay in the same split, so no image is shared
between train, val, and test.
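
A minimal sketch of an image-level split, assuming captions arrive as `(image_id, caption)` pairs. The function name `split_by_image`, the shuffle-then-slice strategy, and reusing seed 42 are illustrative assumptions, not the repository's code; the point is that captions are grouped by image before splitting, so no image can leak across splits.

```python
import random
from collections import defaultdict


def split_by_image(pairs, n_val, n_test, seed=42):
    """Split (image_id, caption) pairs so that every image lands in exactly one split."""
    captions_by_image = defaultdict(list)
    for image_id, caption in pairs:
        captions_by_image[image_id].append(caption)

    # Shuffle image IDs (not individual captions) with a fixed seed, then slice.
    image_ids = sorted(captions_by_image)
    random.Random(seed).shuffle(image_ids)
    test_ids = image_ids[:n_test]
    val_ids = image_ids[n_test:n_test + n_val]
    train_ids = image_ids[n_test + n_val:]

    def gather(ids):
        return {image_id: captions_by_image[image_id] for image_id in ids}

    return gather(train_ids), gather(val_ids), gather(test_ids)


# Toy data: 10 images with 5 captions each.
pairs = [(f"img{i}.jpg", f"caption {j} of image {i}") for i in range(10) for j in range(5)]
train, val, test = split_by_image(pairs, n_val=2, n_test=3)
print(len(train), len(val), len(test))  # 5 2 3
```

For the card's split this would be called with `n_val=1000, n_test=1873`, leaving the remaining images for training.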

| Metric | Value |
|---|---|
| BLEU-1 | 0.6859 |
| BLEU-2 | 0.5289 |
| BLEU-3 | 0.4041 |
| **BLEU-4** | **0.3093** |
| METEOR | 0.4709 |
| CIDEr | 0.7961 |
| ROUGE-L | 0.5257 |

Beam search ranks hypotheses by length-normalized log-probability (`alpha = 0.7`)
and applies a repetition penalty of `1.2`.
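
A minimal sketch of the two scoring tweaks, assuming common formulations: the summed log-probability divided by `length ** alpha` (some implementations use the GNMT variant `((5 + length) / 6) ** alpha` instead), and a CTRL-style repetition penalty applied to the logits of tokens already generated. The repository's exact formulas are not documented here, so treat the function names and details as hypothetical.

```python
import math


def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """CTRL-style repetition penalty (assumed form): lower the logits of tokens already generated."""
    out = list(logits)
    for tok in set(generated_ids):
        # Dividing a positive logit or multiplying a negative one both make that token less likely.
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def length_normalized_score(sum_log_prob, length, alpha=0.7):
    """Rank beams by sum(log p) / length**alpha so longer captions are not automatically penalized."""
    return sum_log_prob / (length ** alpha)


# A 4-token and an 8-token hypothesis: the longer one has the lower raw log-probability
# (-5.5 < -4.0) but the higher score once both are length-normalized.
short = length_normalized_score(-4.0, 4)
long_ = length_normalized_score(-5.5, 8)
print(long_ > short)  # True
```

Without normalization (`alpha = 0`) beam search tends to prefer short captions, since every extra token adds a negative log-probability; `alpha` between 0 and 1 softens that bias without removing it.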

## Architecture

```
Image (3, 224, 224)
 └─ ResNet50 (pretrained; frozen for the first 10 epochs, then last 2 blocks fine-tuned)
      output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
 └─ Bahdanau attention: V·tanh(W_enc(features) + W_dec(h_prev))
      output: context vector (B, 2048), attention weights (B, 49)
 └─ GRUCell (one step per token; re-queries attention at every step)
      hidden state size: 1024, embedding size: 300 (GloVe 6B 300d)
 └─ Linear → vocab logits (V = 10,111)
```
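
One decoding step of the diagram above can be sketched in PyTorch as follows. This is a hypothetical re-implementation for illustration, not the repository's module: the class names, the attention dimension of 512, the zero initial hidden state, and feeding `[embedding; context]` into the `GRUCell` are assumptions. (*Show, Attend and Tell* instead predicts the initial state from the mean image feature.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BahdanauAttention(nn.Module):
    """Additive attention: e_i = v^T tanh(W_enc a_i + W_dec h_prev), alpha = softmax(e)."""

    def __init__(self, enc_dim=2048, dec_dim=1024, attn_dim=512):
        super().__init__()
        self.W_enc = nn.Linear(enc_dim, attn_dim)
        self.W_dec = nn.Linear(dec_dim, attn_dim)
        self.v = nn.Linear(attn_dim, 1)

    def forward(self, features, h_prev):
        # features: (B, 49, enc_dim); h_prev: (B, dec_dim)
        energy = torch.tanh(self.W_enc(features) + self.W_dec(h_prev).unsqueeze(1))  # (B, 49, attn_dim)
        alpha = F.softmax(self.v(energy).squeeze(-1), dim=1)                            # (B, 49)
        context = (alpha.unsqueeze(-1) * features).sum(dim=1)                           # (B, enc_dim)
        return context, alpha


class AttentionGRUStep(nn.Module):
    """One decoding step: attend using h_{t-1}, then update the GRU with [embedding; context]."""

    def __init__(self, vocab_size=10111, embed_dim=300, enc_dim=2048, hidden_dim=1024):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)  # GloVe vectors would be copied in here
        self.attention = BahdanauAttention(enc_dim, hidden_dim)
        self.gru = nn.GRUCell(embed_dim + enc_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tokens, features, h_prev):
        context, alpha = self.attention(features, h_prev)
        h = self.gru(torch.cat([self.embed(tokens), context], dim=1), h_prev)
        return self.fc(h), h, alpha


B = 2
fmap = torch.randn(B, 2048, 7, 7)              # stand-in for the ResNet50 output
features = fmap.flatten(2).permute(0, 2, 1)    # (B, 2048, 49) -> (B, 49, 2048)
step = AttentionGRUStep()
h0 = torch.zeros(B, 1024)                      # zero initial state (assumption)
start = torch.full((B,), 1, dtype=torch.long)  # <start> = 1
logits, h1, alpha = step(start, features, h0)
```

Calling the step in a loop, feeding back either the ground-truth token or the model's own prediction, yields one row of attention weights per generated word.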

Total parameters: **~37 M** (~25 M in the ResNet50 encoder, fully frozen only during
Phase A; ~12 M in the trainable decoder and projection layers).

## Training details

- **Dataset**: Flickr8k + Flickr30k merged (37,000 train / 1,000 val / 1,873 test images)
- **Vocabulary**: 10,111 tokens (frequency threshold 3), built from the train
  captions only. Special tokens: `<pad>=0`, `<start>=1`, `<end>=2`, `<unk>=3`.
- **Loss**: `CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)`, so padding is ignored, plus the
  doubly stochastic attention regularizer `α_c · ((1 − Σ_t α_t)²).mean()`, where `Σ_t α_t` is the
  total attention each of the 49 regions receives over the caption and `α_c = 1.0`
- **Optimizer**: Adam, decoder LR `3.2e-3`, encoder LR `8e-5` (Phase B)
- **Schedule**: `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`, `patience=3`
- **Two-phase training**: Phase A (epochs 1–10) keeps the CNN frozen; Phase B (epochs 11–35) unfreezes the last 2 ResNet blocks.
- **Scheduled sampling**: the probability of feeding the model's own prediction instead of the
  ground-truth token ramps linearly from 0 to a maximum of 0.25 over training
- **Batch size**: 256, gradient clipping at 5.0, seed 42
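
A minimal sketch of the vocabulary rules above, assuming lowercase whitespace tokenization, an inclusive threshold (`count >= 3`), and alphabetical ordering after the four special tokens; the repository's `Vocabulary` class may tokenize and order words differently, and the helper names here are hypothetical.

```python
from collections import Counter

SPECIAL_TOKENS = ["<pad>", "<start>", "<end>", "<unk>"]  # ids 0, 1, 2, 3 as in the card


def build_vocab(train_captions, min_freq=3):
    """Keep words seen at least `min_freq` times in the training captions only."""
    counts = Counter(tok for cap in train_captions for tok in cap.lower().split())
    words = sorted(w for w, c in counts.items() if c >= min_freq)
    itos = SPECIAL_TOKENS + words
    stoi = {w: i for i, w in enumerate(itos)}
    return stoi, itos


def encode(caption, stoi):
    """Map a caption to ids, wrap it in <start>/<end>, and map rare words to <unk>."""
    unk = stoi["<unk>"]
    ids = [stoi.get(tok, unk) for tok in caption.lower().split()]
    return [stoi["<start>"]] + ids + [stoi["<end>"]]


stoi, itos = build_vocab(["a dog runs", "a dog sits", "a dog jumps", "a cat"])
print(itos)                   # ['<pad>', '<start>', '<end>', '<unk>', 'a', 'dog']
print(encode("a cat", stoi))  # [1, 4, 3, 2]
```

Building the vocabulary from training captions only keeps val/test words out of the model's token set, so unseen words in evaluation map to `<unk>` just as they would in deployment.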
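
The combined loss can be sketched as below, assuming the decoder returns logits of shape `(B, T, V)` and attention weights of shape `(B, T, 49)` with zeros at padded steps; the function name `caption_loss` is hypothetical. The regularizer sums each region's attention over time and penalizes deviation from 1, which encourages the decoder to look at every region at some point in the caption.

```python
import torch
import torch.nn as nn

# <pad> = 0 positions contribute nothing; 10% of the target probability mass
# is spread uniformly over all classes.
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)


def caption_loss(logits, targets, alphas, alpha_c=1.0):
    """
    logits:  (B, T, V) decoder outputs for every step
    targets: (B, T) ground-truth ids, padded with 0
    alphas:  (B, T, 49) attention weights, assumed to be 0 at padded steps
    """
    ce = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    # Doubly stochastic regularization: the total attention each region receives
    # over all steps (sum over T) is pushed toward 1.
    reg = alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()
    return ce + reg


torch.manual_seed(0)
logits = torch.randn(2, 4, 10)
targets = torch.randint(1, 10, (2, 4))
alphas = torch.full((2, 4, 49), 1.0 / 49)  # uniform attention at each of 4 steps
loss = caption_loss(logits, targets, alphas)
```

With uniform attention over 4 steps each region collects only 4/49 of a unit, so the regularizer contributes (45/49)² on top of the cross-entropy term.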
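
A sketch of the scheduled-sampling ramp, assuming the ramp is linear over all 35 epochs (1-indexed) and that the coin is flipped independently for each token; the helper names are hypothetical. In a batched decoder the same choice is usually made with a tensor mask, e.g. `torch.where(torch.rand(B) < p, predicted, ground_truth)`.

```python
import random


def sampling_prob(epoch, total_epochs=35, max_prob=0.25):
    """Linear ramp: 0 at epoch 1 up to max_prob at the final epoch (1-indexed epochs)."""
    if total_epochs <= 1:
        return max_prob
    return max_prob * (epoch - 1) / (total_epochs - 1)


def choose_next_input(gt_token, predicted_token, prob, rng=random):
    """With probability `prob`, feed the model's own prediction instead of the ground truth."""
    return predicted_token if rng.random() < prob else gt_token


rng = random.Random(42)
print([round(sampling_prob(e), 4) for e in (1, 18, 35)])  # [0.0, 0.125, 0.25]
```

Capping the probability at 0.25 keeps most steps teacher-forced, which limits how far early mistakes can derail training while still exposing the decoder to its own outputs.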

## Files in this repo

- `attention_gru_glove.pth`: PyTorch checkpoint (encoder and decoder state dicts plus the training config)
- `vocab.pkl`: pickled `Vocabulary` object built from the train split
- `config.json`: JSON copy of the training hyperparameters
- `metrics_beam5.json`: full test-set metrics (beam search, k = 5)
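
## Usage

```bash
git clone https://github.com/OmarGamal488/flickr-image-captioning.git
cd flickr-image-captioning
uv sync
```

Then, in Python started from the repository root (for example with `uv run python`), so that
the `src` package and the `Vocabulary` class needed to unpickle `vocab.pkl` are importable: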

```python
import pickle

from huggingface_hub import hf_hub_download
from src.inference import load_attention_model, caption_image
from src.utils import get_device

repo_id = "OmarGamal48812/flickr-captioning"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_gru_glove.pth")
vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

device = get_device()

# Unpickling requires the repo's Vocabulary class to be importable.
# pickle can execute code, so only load files from sources you trust.
with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

# Rebuild encoder and decoder from the checkpoint; cfg is the config stored alongside them.
encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)

# Beam search with k = 5: returns the best caption and the list of beam hypotheses.
caption, beams = caption_image(
    encoder, decoder, "your_image.jpg", vocab, device,
    method="beam", beam_width=5,
)
print(caption)
for b in beams[:3]:  # first three beams with their scores
    print(f"  {b.score:+.3f}  {b.caption}")
```

## Limitations

- **Domain.** Trained on Flickr8k + Flickr30k photos (mostly people, dogs, and
  outdoor scenes). Performance degrades on cartoons, screenshots, and abstract imagery.
- **Safe-word bias.** Only 8.8% of the 10,111-word vocabulary (roughly 890 distinct words)
  appears in the captions generated at inference; the decoder converges on template
  phrases such as *"a man in a white shirt is standing"*.
- **No object counting.** The attention context vector is a weighted average of region
  features, so it does not encode how many instances of an object are present; the
  model often says "a dog" when the image shows two dogs.
- **Hallucinations.** The decoder can insert objects that are not in the image when the
  visual evidence is weak and the language-model prior takes over.
- **English only.** The vocabulary and grammar come entirely from English Flickr captions.
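
Vocabulary coverage of the kind quoted under *Safe-word bias* can be measured with a few lines of Python. This is a sketch that assumes whitespace tokenization and excludes special tokens; the card does not say exactly how the 8.8% was computed, and `vocab_coverage` is a hypothetical helper.

```python
def vocab_coverage(generated_captions, vocab_size,
                   special_tokens=("<pad>", "<start>", "<end>", "<unk>")):
    """Fraction of the vocabulary that appears at least once in the generated captions."""
    used = {tok for cap in generated_captions for tok in cap.lower().split()}
    used -= set(special_tokens)
    return len(used) / vocab_size, used


captions = [
    "a man in a white shirt is standing",
    "a man in a blue shirt is standing",
    "a dog runs",
]
coverage, used = vocab_coverage(captions, vocab_size=10111)
print(len(used))  # 10
```

Run over the generated test-set captions, a low value means the decoder relies on a small set of high-frequency words even though the rest of the vocabulary was learned.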
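
The counting limitation follows from how the context vector is formed. Writing $a_i$ for the 49 region features and $\alpha_{t,i}$ for the attention weights at step $t$, and idealizing two dogs as having identical features $a_{\text{dog}}$ (a simplifying assumption), one dog and two dogs produce the same context:

```latex
\begin{aligned}
c_t &= \sum_{i=1}^{49} \alpha_{t,i}\, a_i, \qquad \sum_{i=1}^{49} \alpha_{t,i} = 1 \\
\text{one dog in region } j\ (\alpha_{t,j} = 1):\quad c_t &= a_{\text{dog}} \\
\text{two dogs in regions } j, k\ (\alpha_{t,j} = \alpha_{t,k} = \tfrac{1}{2}):\quad c_t &= \tfrac{1}{2}\, a_{\text{dog}} + \tfrac{1}{2}\, a_{\text{dog}} = a_{\text{dog}}
\end{aligned}
```

Because the weights are normalized, spreading attention over more instances rescales each contribution instead of adding them up.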

## Citation

If you use this checkpoint, please cite the three papers this work builds on:

```bibtex
@inproceedings{xu2015show,
  title     = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
  author    = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and
               Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
  booktitle = {ICML},
  year      = {2015}
}

@article{bahdanau2014neural,
  title   = {Neural Machine Translation by Jointly Learning to Align and Translate},
  author  = {Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
  journal = {arXiv preprint arXiv:1409.0473},
  year    = {2014}
}

@inproceedings{selvaraju2017gradcam,
  title     = {Grad-{CAM}: Visual Explanations from Deep Networks via Gradient-based Localization},
  author    = {Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and
               Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
  booktitle = {ICCV},
  year      = {2017}
}
```