---
language: en
license: mit
tags:
- image-captioning
- pytorch
- resnet
- attention
- gru
- glove
- flickr8k
- flickr30k
- show-attend-and-tell
datasets:
- nlphuji/flickr8k
metrics:
- bleu
- meteor
- cider
- rouge
library_name: pytorch
pipeline_tag: image-to-text
---

# Flickr Image Captioning — ResNet50 + Bahdanau Attention + GRU + GloVe

This model generates a natural-language description of an image. It uses a **ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention module, and a **GRU decoder** initialized with **GloVe 6B 300d** embeddings, trained on the merged **Flickr8k + Flickr30k** dataset (39,874 images × 5 captions). It follows the architecture of [*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044) with label smoothing, scheduled sampling, and two-phase CNN fine-tuning.

## Test-set performance (beam search, k = 5)

Evaluated on the held-out 1,873-image test split. The split is made at the image level, so no image's captions appear in more than one of train/val/test.

| Metric | Value |
|---|---|
| BLEU-1 | 0.6859 |
| BLEU-2 | 0.5289 |
| BLEU-3 | 0.4041 |
| **BLEU-4** | **0.3093** |
| METEOR | 0.4709 |
| CIDEr | 0.7961 |
| ROUGE-L | 0.5257 |

Beam search uses length-normalized log-probabilities (`alpha = 0.7`) and a repetition penalty of `1.2`.

## Architecture

```
Image (3, 224, 224)
 └─ ResNet50 (pretrained, frozen for the first 10 epochs, last 2 blocks fine-tuned afterwards)
      output: (B, 2048, 7, 7) → reshaped to (B, 49, 2048)
 └─ Bahdanau attention   V·tanh(W_enc(features) + W_dec(h_prev))
      output: context vector (B, 2048), attention weights (B, 49)
 └─ GRUCell (per timestep — re-queries attention at each step)
      hidden state size: 1024, embedding size: 300 (GloVe 6B 300d)
 └─ Linear → vocab logits (V = 10,111)
```

Total parameters: **~37 M** (≈25 M ResNet50 encoder, largely frozen; ≈12 M trainable decoder/attention/projection). A minimal code sketch of the attention step is included after the Training details list below.

## Training details

- **Dataset** — Flickr8k + Flickr30k merged (37,000 train / 1,000 val / 1,873 test)
- **Vocabulary** — 10,111 tokens (frequency threshold 3), built from train captions only. Four special tokens occupy indices 0–3: padding (index 0, matching `ignore_index=0` in the loss), start, end, and unknown.
- **Loss** — `CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)` plus the doubly-stochastic attention regularizer `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0` (see the loss sketch after this list)
- **Optimizer** — Adam, decoder LR `3.2e-3`, encoder LR `8e-5` (Phase B)
- **Schedule** — `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`, `patience=3`
- **Two-phase training** — Phase A (epochs 1–10): freeze the CNN. Phase B (epochs 11–35): unfreeze the last 2 ResNet blocks.
- **Scheduled sampling** — linear ramp from 0 to a maximum of 0.25 over the training epochs
- **Batch size** — 256, gradient clip 5.0, seed 42
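To make the Architecture diagram above concrete, here is a minimal PyTorch sketch of the additive (Bahdanau) attention step, `V·tanh(W_enc(features) + W_dec(h_prev))`. It is illustrative only: the class name `AdditiveAttention` and the hidden size `attn_dim = 512` are assumptions, not the identifiers or values used in this repository.

```python
import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    """Illustrative Bahdanau attention: score = V·tanh(W_enc(features) + W_dec(h_prev))."""

    def __init__(self, enc_dim: int = 2048, dec_dim: int = 1024, attn_dim: int = 512):
        super().__init__()
        self.w_enc = nn.Linear(enc_dim, attn_dim)  # projects the 49 spatial feature vectors
        self.w_dec = nn.Linear(dec_dim, attn_dim)  # projects the previous GRU hidden state
        self.v = nn.Linear(attn_dim, 1)            # scores each spatial location

    def forward(self, features: torch.Tensor, h_prev: torch.Tensor):
        # features: (B, 49, enc_dim), h_prev: (B, dec_dim)
        scores = self.v(torch.tanh(self.w_enc(features) + self.w_dec(h_prev).unsqueeze(1)))
        alpha = torch.softmax(scores.squeeze(-1), dim=1)              # (B, 49) attention weights
        context = (alpha.unsqueeze(-1) * features).sum(dim=1)         # (B, enc_dim) context vector
        return context, alpha
```

The decoder re-queries this module at every timestep, so each generated word can attend to a different part of the 7×7 feature grid.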
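The loss bullet above combines label-smoothed cross-entropy with the doubly-stochastic attention penalty from *Show, Attend and Tell*. The following is a minimal sketch of that objective under assumed tensor shapes; the function name `captioning_loss` is hypothetical and not part of this repo's API.

```python
import torch
import torch.nn as nn


def captioning_loss(logits: torch.Tensor, targets: torch.Tensor,
                    alphas: torch.Tensor, alpha_c: float = 1.0) -> torch.Tensor:
    """logits: (B, T, V), targets: (B, T) with 0 = padding, alphas: (B, T, 49)."""
    ce = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    word_loss = ce(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    # Doubly-stochastic regularizer: the total attention each of the 49 spatial
    # locations receives over the caption should be close to 1.
    attn_penalty = alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()
    return word_loss + attn_penalty
```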
## Files in this repo

- `attention_gru_glove.pth` — PyTorch checkpoint (encoder + decoder state dicts, config)
- `vocab.pkl` — pickled `Vocabulary` object built from the train split
- `config.json` — JSON copy of the training hyperparameters
- `metrics_beam5.json` — full test-set metrics (beam search, k = 5)

## Usage

```bash
git clone https://github.com/OmarGamal488/flickr-image-captioning.git
cd flickr-image-captioning
uv sync
```

Then in Python:

```python
import pickle, torch
from huggingface_hub import hf_hub_download

from src.inference import load_attention_model, caption_image
from src.utils import get_device

repo_id = "OmarGamal48812/flickr-captioning"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_gru_glove.pth")
vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

device = get_device()
with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)

caption, beams = caption_image(
    encoder, decoder, "your_image.jpg", vocab, device,
    method="beam", beam_width=5,
)
print(caption)
for b in beams[:3]:
    print(f"  {b.score:+.3f}  {b.caption}")
```

## Limitations

- **Domain.** Trained on Flickr8k + Flickr30k photos (mostly people, dogs, and outdoor scenes). Performance degrades on cartoons, screenshots, and abstract imagery.
- **Safe, generic captions.** Only 8.8% of the 10,111-word vocabulary is actually used at inference — the decoder converges on template phrases such as *"a man in a white shirt is standing"*.
- **No object counting.** The attention context vector discards count information — the model often says "a dog" when the image shows two dogs.
- **Hallucinations.** The decoder can insert objects that are not in the image when visual evidence is weak and the language-model prior takes over.
- **English only.** Vocabulary and grammar come entirely from English Flickr captions.

## Citation

If you use this checkpoint, please cite the three papers this work builds on:

```bibtex
@inproceedings{xu2015show,
  title     = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
  author    = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
  booktitle = {ICML},
  year      = {2015}
}

@article{bahdanau2014neural,
  title   = {Neural Machine Translation by Jointly Learning to Align and Translate},
  author  = {Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
  journal = {arXiv preprint arXiv:1409.0473},
  year    = {2014}
}

@inproceedings{selvaraju2017gradcam,
  title     = {Grad-{CAM}: Visual Explanations from Deep Networks via Gradient-based Localization},
  author    = {Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
  booktitle = {ICCV},
  year      = {2017}
}
```