---
language: en
license: mit
tags:
- image-captioning
- pytorch
- resnet
- attention
- gru
- glove
- flickr8k
- flickr30k
- show-attend-and-tell
datasets:
- nlphuji/flickr8k
metrics:
- bleu
- meteor
- cider
- rouge
library_name: pytorch
pipeline_tag: image-to-text
---
# Flickr Image Captioning: ResNet50 + Bahdanau Attention + GRU + GloVe
This model generates a natural-language description of an image. It uses a
**ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
module, and a **GRU decoder** initialized with **GloVe 6B 300d** embeddings,
trained on the merged **Flickr8k + Flickr30k** dataset (39,874 images × 5
captions). It follows the architecture from
[*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044)
with label smoothing, scheduled sampling, and two-phase CNN fine-tuning.
## Test-set performance (beam search, k = 5)
Evaluated on the held-out 1,873-image test split (image-level split, so no
image's captions cross train/val/test).
| Metric | Value |
|---|---|
| BLEU-1 | 0.6859 |
| BLEU-2 | 0.5289 |
| BLEU-3 | 0.4041 |
| **BLEU-4** | **0.3093** |
| METEOR | 0.4709 |
| CIDEr | 0.7961 |
| ROUGE-L | 0.5257 |

Beam search uses length-normalized log-probs (`alpha = 0.7`) and a
repetition penalty of `1.2`.
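
The two beam-search settings above can be illustrated with a minimal sketch. The exact formulas used in this repository are not stated here, so both functions are assumptions: the score is taken to be the summed log-probability divided by `length ** alpha`, and the repetition penalty is taken to multiply the log-probability of tokens that were already generated. The names `length_normalized_score` and `penalize_repeats` are hypothetical.

```python
def length_normalized_score(sum_log_prob, length, alpha=0.7):
    """Assumed form: summed log-probability divided by length ** alpha."""
    return sum_log_prob / (length ** alpha)


def penalize_repeats(log_probs, generated_ids, penalty=1.2):
    """Assumed form: scale the log-prob of already-generated token ids.

    Log-probs are <= 0, so multiplying by a penalty > 1 lowers them further.
    """
    seen = set(generated_ids)
    return [lp * penalty if i in seen else lp for i, lp in enumerate(log_probs)]


# Toy hypotheses: a 4-token caption and an 8-token caption.
short_score = length_normalized_score(-4.0, length=4)
long_score = length_normalized_score(-6.0, length=8)

# Token id 1 was already emitted, so only its log-prob is scaled.
penalized = penalize_repeats([-1.0, -2.0, -3.0], generated_ids=[1])
```

With these toy numbers the longer hypothesis ends up with the higher normalized score even though its raw sum is lower, which is the effect length normalization is meant to have.
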
## Architecture
```
Image (3, 224, 224)
  └─ ResNet50 (pretrained, frozen first 10 epochs, last 2 blocks fine-tuned)
       output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
  └─ Bahdanau attention  V·tanh(W_enc(features) + W_dec(h_prev))
       output: context vector (B, 2048), attention weights (B, 49)
  └─ GRUCell (per timestep, re-queries attention each step)
       hidden state size: 1024, embedding size: 300 (GloVe 6B 300d)
  └─ Linear → vocab logits (V = 10,111)
```
Total parameters: **~37 M** (25 M frozen ResNet, 12 M trainable decoder/projection).
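
The attention step in the diagram can be sketched in PyTorch as follows. This is an illustrative module, not the code from the repository: the class name `AdditiveAttention` and the attention dimension of 512 are assumptions, while the feature size (2048), number of regions (49), and hidden size (1024) come from the diagram.

```python
import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    """Bahdanau-style attention over 49 image regions (illustrative sketch)."""

    def __init__(self, feature_dim=2048, hidden_dim=1024, attn_dim=512):
        super().__init__()
        self.w_enc = nn.Linear(feature_dim, attn_dim)  # projects each image region
        self.w_dec = nn.Linear(hidden_dim, attn_dim)   # projects the previous hidden state
        self.v = nn.Linear(attn_dim, 1)                # one score per region

    def forward(self, features, h_prev):
        # features: (B, 49, 2048), h_prev: (B, 1024)
        energy = torch.tanh(self.w_enc(features) + self.w_dec(h_prev).unsqueeze(1))
        scores = self.v(energy).squeeze(-1)                       # (B, 49)
        weights = torch.softmax(scores, dim=1)                    # (B, 49), sums to 1
        context = (weights.unsqueeze(-1) * features).sum(dim=1)   # (B, 2048)
        return context, weights


# Random stand-in for the ResNet50 output: (B, 2048, 7, 7) -> (B, 49, 2048)
feature_map = torch.randn(2, 2048, 7, 7)
features = feature_map.flatten(2).permute(0, 2, 1)
context, weights = AdditiveAttention()(features, torch.zeros(2, 1024))
```

Because the decoder is a `GRUCell` stepped one token at a time, a module like this would be called once per timestep with the latest hidden state.
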
## Training details
- **Dataset**: Flickr8k + Flickr30k merged (37,000 train / 1,000 val / 1,873 test)
- **Vocabulary**: 10,111 tokens (frequency threshold 3), built from train
  captions only. Special tokens: `<pad>=0`, `<start>=1`, `<end>=2`, `<unk>=3`.
- **Loss**: `CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)` plus
  doubly-stochastic regularization `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0`
- **Optimizer**: Adam, decoder LR `3.2e-3`, encoder LR `8e-5` (Phase B)
- **Schedule**: `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`, `patience=3`
- **Two-phase training**: Phase A (epochs 1–10): freeze CNN. Phase B (epochs 11–35): unfreeze last 2 ResNet blocks.
- **Scheduled sampling**: linear ramp from 0 to max 0.25 over training epochs
- **Batch size**: 256, gradient clip 5.0, seed 42
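
The loss and the scheduled-sampling ramp from the list above can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's training code: the function names are hypothetical, the tensor layout `(B, T, 49)` for the attention weights is assumed, and the ramp is assumed to start at 0 in epoch 1 and reach 0.25 in epoch 35.

```python
import torch
import torch.nn as nn


def scheduled_sampling_prob(epoch, total_epochs=35, max_prob=0.25):
    """Linear ramp from 0 at epoch 1 to max_prob at the last epoch (assumed endpoints)."""
    if total_epochs <= 1:
        return max_prob
    return max_prob * (epoch - 1) / (total_epochs - 1)


def captioning_loss(logits, targets, alphas, alpha_c=1.0):
    """Label-smoothed cross-entropy plus the doubly-stochastic attention penalty.

    logits:  (B, T, V) vocabulary scores
    targets: (B, T) token ids, where 0 is <pad> and is ignored
    alphas:  (B, T, 49) attention weights for every timestep (assumed layout)
    """
    ce = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    ce_loss = ce(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    # Sum attention over timesteps; penalize regions whose total is far from 1.
    reg = alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()
    return ce_loss + reg
```

The penalty term pushes the decoder to look at every image region at some point during generation instead of attending to the same few regions at every step.
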
## Files in this repo
- `attention_gru_glove.pth`: PyTorch checkpoint (encoder + decoder state dicts, config)
- `vocab.pkl`: pickled `Vocabulary` object built from the train split
- `config.json`: JSON copy of the training hyperparameters
- `metrics_beam5.json`: full test-set metrics (beam search k=5)
## Usage
```bash
git clone https://github.com/OmarGamal488/flickr-image-captioning.git
cd flickr-image-captioning
uv sync
```
Then in Python:
```python
import pickle

import torch
from huggingface_hub import hf_hub_download
from src.inference import load_attention_model, caption_image
from src.utils import get_device

# Download the checkpoint and vocabulary from the Hugging Face Hub.
repo_id = "OmarGamal48812/flickr-captioning"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_gru_glove.pth")
vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

device = get_device()
# Unpickling needs the repo's Vocabulary class to be importable,
# so run this from inside the cloned repository.
with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)
caption, beams = caption_image(
    encoder, decoder, "your_image.jpg", vocab, device,
    method="beam", beam_width=5,
)
print(caption)
# Print the first three beams with their scores.
for b in beams[:3]:
    print(f" {b.score:+.3f} {b.caption}")
```
## Limitations
- **Domain.** Trained on Flickr8k + Flickr30k photos (mostly people, dogs,
  outdoor scenes). Performance degrades on cartoons, screenshots, and abstract imagery.
- **Safe-word bias.** Only 8.8% of the 10,111-word vocabulary is used at inference;
  the decoder converges on template phrases like *"a man in a white shirt is standing"*.
- **No object counting.** The attention context vector collapses object count;
  the model often says "a dog" when the image shows two dogs.
- **Hallucinations.** The decoder can insert objects not in the image when visual
  evidence is weak and the language-model prior takes over.
- **English only.** Vocabulary and grammar are entirely from English Flickr captions.
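
A vocabulary-usage figure like the one in the safe-word bullet can be estimated with a few lines of Python. How the 8.8% was actually measured is not stated here, so this is only one plausible way to compute it: the hypothetical helper `vocabulary_usage` counts distinct whitespace-separated tokens across generated captions and divides by the vocabulary size.

```python
def vocabulary_usage(captions, vocab_size):
    """Fraction of the vocabulary that appears at least once in the captions.

    Assumes every caption token comes from the model's vocabulary.
    """
    used = {token for caption in captions for token in caption.lower().split()}
    return len(used) / vocab_size


# Toy example with a 20-word vocabulary; the captions use 9 distinct tokens.
generated = [
    "a man in a white shirt is standing",
    "a dog is running",
]
usage = vocabulary_usage(generated, vocab_size=20)
```

Running the same function over all generated test captions with `vocab_size=10111` would give a number comparable to the one quoted above, under the tokenization assumption stated in the docstring.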
## Citation
If you use this checkpoint, please cite the three papers this work builds on:
```bibtex
@inproceedings{xu2015show,
  title     = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
  author    = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and
               Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
  booktitle = {ICML},
  year      = {2015}
}
@article{bahdanau2014neural,
  title   = {Neural Machine Translation by Jointly Learning to Align and Translate},
  author  = {Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
  journal = {arXiv preprint arXiv:1409.0473},
  year    = {2014}
}
@inproceedings{selvaraju2017gradcam,
  title     = {Grad-{CAM}: Visual Explanations from Deep Networks via Gradient-based Localization},
  author    = {Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and
               Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
  booktitle = {ICCV},
  year      = {2017}
}
```