---
language: en
license: mit
tags:
  - image-captioning
  - pytorch
  - resnet
  - attention
  - gru
  - glove
  - flickr8k
  - flickr30k
  - show-attend-and-tell
datasets:
  - nlphuji/flickr8k
metrics:
  - bleu
  - meteor
  - cider
  - rouge
library_name: pytorch
pipeline_tag: image-to-text
---

# Flickr Image Captioning: ResNet50 + Bahdanau Attention + GRU + GloVe

This model generates a natural-language description of an image. It uses a
**ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
module, and a **GRU decoder** initialized with **GloVe 6B 300d** embeddings,
trained on the merged **Flickr8k + Flickr30k** dataset (39,874 images × 5
captions). It follows the architecture from
[*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044)
with label smoothing, scheduled sampling, and two-phase CNN fine-tuning.

## Test-set performance (beam search, k = 5)

Evaluated on the held-out 1,873-image test split. The split is made at the
image level, so all captions of an image stay in one partition and none cross
train/val/test.

| Metric | Value |
|---|---|
| BLEU-1 | 0.6859 |
| BLEU-2 | 0.5289 |
| BLEU-3 | 0.4041 |
| **BLEU-4** | **0.3093** |
| METEOR | 0.4709 |
| CIDEr | 0.7961 |
| ROUGE-L | 0.5257 |

Beam search uses length-normalized log-probs (`alpha = 0.7`) and a
repetition penalty of `1.2`.
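
The exact scoring formula is not spelled out here. The following is a minimal sketch, assuming the common convention of dividing the summed log-probability by `length ** alpha` and a CTRL-style repetition penalty applied to the logits of tokens already generated. The function names are hypothetical and are not taken from the repository.

```python
import math


def length_normalized_score(token_logprobs, alpha=0.7):
    """Summed token log-probs divided by length**alpha (assumed convention)."""
    return sum(token_logprobs) / (len(token_logprobs) ** alpha)


def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Lower the logits of tokens that were already generated (assumed CTRL-style).

    Positive logits are divided by the penalty and negative logits are
    multiplied by it, so a penalized token always becomes less likely.
    """
    out = list(logits)
    for idx in set(generated_ids):
        out[idx] = out[idx] / penalty if out[idx] > 0 else out[idx] * penalty
    return out


# Two hypothetical beams. By raw log-prob the shorter one wins simply
# because it accumulates fewer negative terms.
short_beam = [math.log(0.5)] * 4
long_beam = [math.log(0.6)] * 8

print(sum(short_beam) > sum(long_beam))
print(length_normalized_score(long_beam) > length_normalized_score(short_beam))
```

With `alpha = 0` the score reduces to the raw sum, and with `alpha = 1` it becomes the per-token average; `0.7` sits between the two, which reduces the bias toward very short captions without fully rewarding length.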

## Architecture

```
Image (3, 224, 224)
  └─ ResNet50 (pretrained, frozen first 10 epochs, last 2 blocks fine-tuned)
       output: (B, 2048, 7, 7)  → reshape to (B, 49, 2048)
  └─ Bahdanau attention  V·tanh(W_enc(features) + W_dec(h_prev))
       output: context vector (B, 2048), attention weights (B, 49)
  └─ GRUCell  (per timestep; re-queries attention each step)
       hidden state size: 1024, embedding size: 300 (GloVe 6B 300d)
  └─ Linear → vocab logits (V = 10,111)
```
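
The attention step in the diagram can be sketched in plain Python (no PyTorch required) to make the shapes concrete. This is an illustration, not the repository's implementation: `matvec`, `rand_matrix`, `additive_attention`, the attention width `D_ATT`, and the toy dimensions are all assumptions, and bias terms are omitted.

```python
import math
import random


def matvec(W, x):
    """Multiply a (rows x cols) matrix stored as nested lists by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def rand_matrix(rng, rows, cols):
    """Random matrix with entries in [-1, 1], used only for this demo."""
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def additive_attention(features, h_prev, W_enc, W_dec, v):
    """Score each region with v . tanh(W_enc f + W_dec h_prev), then softmax.

    features: list of R region vectors (R = 49 for a 7x7 feature grid)
    h_prev:   previous decoder hidden state
    Returns (context, weights): the weighted sum of the region vectors and
    one attention weight per region.
    """
    dec_proj = matvec(W_dec, h_prev)  # identical for every region
    scores = []
    for f in features:
        enc_proj = matvec(W_enc, f)
        hidden = [math.tanh(e + d) for e, d in zip(enc_proj, dec_proj)]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    # Numerically stable softmax over the region axis.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    context = [
        sum(w * f[k] for w, f in zip(weights, features))
        for k in range(len(features[0]))
    ]
    return context, weights


# Toy sizes so the example runs instantly. The model itself uses
# 49 regions, 2048 feature channels, and a 1024-d hidden state.
rng = random.Random(0)
R, D_FEAT, D_HID, D_ATT = 49, 8, 6, 5
features = rand_matrix(rng, R, D_FEAT)
h_prev = rand_matrix(rng, 1, D_HID)[0]
W_enc = rand_matrix(rng, D_ATT, D_FEAT)
W_dec = rand_matrix(rng, D_ATT, D_HID)
v = rand_matrix(rng, 1, D_ATT)[0]

context, weights = additive_attention(features, h_prev, W_enc, W_dec, v)
print(len(context), len(weights))
```

Because `h_prev` changes at every decoding step, the weights are recomputed per token, which is what "re-queries attention each step" refers to.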

Total parameters: **~37 M** (about 25 M in the ResNet50 encoder, which is frozen during Phase A, and about 12 M in the trainable decoder and projection layers).

## Training details

- **Dataset**: Flickr8k + Flickr30k merged (37,000 train / 1,000 val / 1,873 test)
- **Vocabulary**: 10,111 tokens (frequency threshold 3), built from train
  captions only. Special tokens: `<pad>=0`, `<start>=1`, `<end>=2`, `<unk>=3`.
- **Loss**: `CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)` plus
  doubly-stochastic regularization `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0`
- **Optimizer**: Adam, decoder LR `3.2e-3`, encoder LR `8e-5` (Phase B)
- **Schedule**: `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`, `patience=3`
- **Two-phase training**: Phase A (epochs 1–10) freezes the CNN. Phase B (epochs 11–35) unfreezes the last 2 ResNet blocks.
- **Scheduled sampling**: the sampling probability ramps linearly from 0 to a maximum of 0.25 over the training epochs
- **Batch size**: 256, gradient clip 5.0, seed 42
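
Two of the items above reduce to short formulas. The sketch below shows the doubly-stochastic attention penalty and a linear scheduled-sampling ramp in plain Python. The function names are hypothetical, and the ramp assumes 1-indexed epochs spanning all 35 epochs; the repository may define the ramp endpoints differently.

```python
def doubly_stochastic_penalty(alphas, alpha_c=1.0):
    """alpha_c * mean over regions of (1 - sum_t alpha_t)^2.

    alphas[t][i] is the attention weight on region i at decoding step t.
    Each row already sums to 1 (softmax over regions). The penalty also
    pushes each column sum toward 1, so that every region receives
    attention at some point during the caption.
    """
    num_regions = len(alphas[0])
    column_sums = [sum(step[i] for step in alphas) for i in range(num_regions)]
    return alpha_c * sum((1.0 - s) ** 2 for s in column_sums) / num_regions


def sampling_probability(epoch, total_epochs=35, max_prob=0.25):
    """Probability of feeding the model its own prediction at this epoch.

    Linear ramp: 0 at epoch 1, max_prob at the final epoch (assumed endpoints).
    """
    if total_epochs <= 1:
        return max_prob
    return max_prob * (epoch - 1) / (total_epochs - 1)


# Attention that visits each of two regions once incurs no penalty ...
print(doubly_stochastic_penalty([[1.0, 0.0], [0.0, 1.0]]))
# ... while attention stuck on the first region is penalized.
print(doubly_stochastic_penalty([[1.0, 0.0], [1.0, 0.0]]))
print(sampling_probability(1), sampling_probability(18), sampling_probability(35))
```

In training, the penalty would be added to the cross-entropy loss, and `sampling_probability` would decide per token whether the decoder input is the ground-truth word or the model's previous prediction.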

## Files in this repo

- `attention_gru_glove.pth`: PyTorch checkpoint (encoder + decoder state dicts, config)
- `vocab.pkl`: pickled `Vocabulary` object built from the train split
- `config.json`: JSON copy of the training hyperparameters
- `metrics_beam5.json`: full test-set metrics (beam search k=5)

## Usage

```bash
git clone https://github.com/OmarGamal488/flickr-image-captioning.git
cd flickr-image-captioning
uv sync
```

Then in Python:

```python
import pickle, torch
from huggingface_hub import hf_hub_download
from src.inference import load_attention_model, caption_image
from src.utils import get_device

repo_id = "OmarGamal48812/flickr-captioning"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_gru_glove.pth")
vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

device = get_device()
with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)

caption, beams = caption_image(
    encoder, decoder, "your_image.jpg", vocab, device,
    method="beam", beam_width=5,
)
print(caption)
for b in beams[:3]:
    print(f"  {b.score:+.3f}  {b.caption}")
```

## Limitations

- **Domain.** Trained on Flickr8k + Flickr30k photos (mostly people, dogs,
  outdoor scenes). Performance degrades on cartoons, screenshots, and abstract imagery.
- **Safe-word bias.** Only 8.8% of the 10,111-word vocabulary is used at inference;
  the decoder converges on template phrases like *"a man in a white shirt is standing"*.
- **No object counting.** The attention context vector collapses object count, so
  the model often says "a dog" when the image shows two dogs.
- **Hallucinations.** The decoder can insert objects not in the image when visual
  evidence is weak and the language-model prior takes over.
- **English only.** Vocabulary and grammar are entirely from English Flickr captions.
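
The vocabulary-usage figure can be checked on any set of generated captions. The helper below is a rough sketch, not the evaluation code behind the 8.8% number: `vocabulary_usage` is a hypothetical name, and it assumes lowercase whitespace tokenization and that every generated token is a vocabulary entry.

```python
def vocabulary_usage(captions, vocab_size):
    """Fraction of the vocabulary that appears in a set of generated captions."""
    used = {token for caption in captions for token in caption.lower().split()}
    return len(used) / vocab_size


# 8.8% of the 10,111-entry vocabulary is roughly this many distinct words.
print(round(0.088 * 10_111))

# Tiny illustration with a made-up 10-word vocabulary.
print(vocabulary_usage(["a dog runs", "A man runs"], vocab_size=10))
```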

## Citation

If you use this checkpoint, please cite the three papers this work builds on:

```bibtex
@inproceedings{xu2015show,
  title     = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
  author    = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and
               Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
  booktitle = {ICML},
  year      = {2015}
}

@article{bahdanau2014neural,
  title   = {Neural Machine Translation by Jointly Learning to Align and Translate},
  author  = {Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
  journal = {arXiv preprint arXiv:1409.0473},
  year    = {2014}
}

@inproceedings{selvaraju2017gradcam,
  title     = {Grad-{CAM}: Visual Explanations from Deep Networks via Gradient-based Localization},
  author    = {Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and
               Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
  booktitle = {ICCV},
  year      = {2017}
}
```