Upload Flickr8k attention-LSTM checkpoint + model card
- README.md +171 -0
- attention_lstm.pth +3 -0
- config.json +34 -0
- metrics_beam5.json +17 -0
- metrics_greedy.json +17 -0
- vocab.pkl +3 -0
README.md
ADDED
@@ -0,0 +1,171 @@
---
language: en
license: mit
tags:
- image-captioning
- pytorch
- resnet
- attention
- lstm
- flickr8k
- show-attend-and-tell
datasets:
- nlphuji/flickr8k
metrics:
- bleu
- meteor
- cider
- rouge
library_name: pytorch
pipeline_tag: image-to-text
---

# Flickr8k Image Captioning: ResNet50 + Bahdanau Attention + LSTM Decoder

This model generates a natural-language description of an image. It uses a
**ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
module, and an **LSTM decoder**, trained with teacher forcing and doubly
stochastic regularization on the **Flickr8k** dataset (8,091 images × 5
captions). It follows the soft-attention architecture of
[*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044),
with ResNet50 in place of the paper's VGG encoder.

## Test-set performance (beam search, k = 5)

| Metric | Value |
|---|---|
| BLEU-1 | 0.6488 |
| BLEU-2 | 0.4714 |
| BLEU-3 | 0.3378 |
| **BLEU-4** | **0.2403** |
| METEOR | 0.4270 |
| CIDEr | 0.6002 |
| ROUGE-L | 0.4788 |

Greedy decoding scores lower: BLEU-4 = 0.2073, METEOR = 0.4119, CIDEr = 0.5322.

All scores are measured on the held-out 1,091-image test split. The split is
made at the image level, so no image, and therefore none of its captions,
appears in more than one of train/val/test. Beam search ranks hypotheses by
length-normalized log-probabilities (`alpha = 0.7`) and applies a repetition
penalty of `1.2`.

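The card gives the values of `alpha` and the repetition penalty but not the exact formulas. The sketch below shows one common reading, stated here as an assumption: the summed log-probability is divided by `length ** alpha`, and the log-probability of any token already in the hypothesis is multiplied by the penalty. The function names are illustrative and are not taken from the source repo.

```python
import math

ALPHA = 0.7               # length-normalization exponent quoted in the card
REPETITION_PENALTY = 1.2  # repetition penalty quoted in the card


def penalize_repeats(log_probs, generated, penalty=REPETITION_PENALTY):
    """Lower the log-probability of tokens already in the hypothesis.

    Log-probabilities are <= 0, so multiplying by a penalty > 1 makes a
    repeated token less likely. This form of the penalty is an assumption.
    """
    seen = set(generated)
    return {tok: lp * penalty if tok in seen else lp
            for tok, lp in log_probs.items()}


def normalized_score(sum_log_prob, length, alpha=ALPHA):
    """Assumed normalization: summed log-probability / length ** alpha."""
    return sum_log_prob / (length ** alpha)


# Two hypotheses with the same per-token log-probability but different lengths.
short = normalized_score(-0.5 * 4, 4)
long_ = normalized_score(-0.5 * 8, 8)
raw_gap = (-0.5 * 4) - (-0.5 * 8)  # gap between the unnormalized sums
norm_gap = short - long_           # gap after length normalization

# One decoding step where "dog" has already been generated.
step = penalize_repeats(
    {"dog": math.log(0.5), "grass": math.log(0.25)},
    generated=["a", "dog"],
)
```

With `alpha` below 1 the longer hypothesis is still ranked lower, but by much less than under the raw sum, which is what keeps beam search from favoring very short captions.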
## Architecture

```
Image (3, 224, 224)
└─ ResNet50 (pretrained, frozen for the first 15 epochs, then last 2 blocks fine-tuned)
   output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
└─ Bahdanau attention V·tanh(W_enc(features) + W_dec(h_prev))
   output: context vector (B, 2048), attention weights (B, 49)
└─ LSTMCell (per timestep; re-queries attention each step)
   hidden state size: 512, embedding size: 256
└─ Linear → vocab logits (V = 2,557)
```

Total parameters: **~36 M** (28 M in the ResNet encoder, frozen for the first 15 epochs, and 8 M in the trainable decoder/projection).

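To make the attention line of the diagram concrete, here is a minimal pure-Python sketch of additive attention for a single image. It is not the repo's implementation: the helper names are hypothetical, the weights are random, and the sizes are toy stand-ins for 49 regions × 2048 features, a 512-dimensional hidden state, and the 256-dimensional attention space listed in `config.json`.

```python
import math
import random


def matvec(matrix, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def additive_attention(features, h_prev, w_enc, w_dec, v):
    """score_i = v . tanh(W_enc f_i + W_dec h_prev); weights = softmax(scores)."""
    dec_proj = matvec(w_dec, h_prev)  # shared by every region
    scores = []
    for region in features:
        enc_proj = matvec(w_enc, region)
        hidden = [math.tanh(e + d) for e, d in zip(enc_proj, dec_proj)]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    weights = softmax(scores)
    # Context vector: attention-weighted average of the region features.
    context = [sum(w * region[j] for w, region in zip(weights, features))
               for j in range(len(features[0]))]
    return context, weights


# Toy sizes (the real model uses 49 x 2048 features, hidden 512, attention 256).
REGIONS, ENC_DIM, HID_DIM, ATT_DIM = 49, 8, 4, 3
rng = random.Random(0)


def rand_matrix(rows, cols):
    """Random weights, used only to exercise the shapes."""
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


features = rand_matrix(REGIONS, ENC_DIM)
h_prev = [rng.uniform(-1, 1) for _ in range(HID_DIM)]
context, weights = additive_attention(
    features, h_prev,
    w_enc=rand_matrix(ATT_DIM, ENC_DIM),
    w_dec=rand_matrix(ATT_DIM, HID_DIM),
    v=[rng.uniform(-1, 1) for _ in range(ATT_DIM)],
)
```

The weights sum to 1 across the 49 regions, and the context vector has the same dimensionality as one region feature, matching the `(B, 49)` and `(B, 2048)` shapes in the diagram.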
## Training details

- **Loss**: `CrossEntropyLoss(ignore_index=0)` plus doubly stochastic
  regularization `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0`
- **Optimizer**: Adam, decoder LR `4e-4`, encoder LR `1e-5` (Phase B)
- **Schedule**: `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`,
  `patience=3`
- **Two-phase training**: Phase A (15 epochs) freezes the CNN and trains the
  decoder only. Phase B (10 epochs) unfreezes the last 2 ResNet blocks.
- **Vocabulary**: 2,557 tokens (frequency threshold 5), built from train
  captions only. Special tokens: `<pad>=0, <start>=1, <end>=2, <unk>=3`.
- **Batch size**: 32, gradient clip 5.0
- **Seed**: 42

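The regularization term in the loss can be worked through by hand. The sketch below evaluates `α_c · ((1 − Σ_t α_t)²).mean()` on nested lists, assuming the attention weights are indexed as `alphas[sample][timestep][region]` and that the mean runs over samples and regions; the function name is illustrative.

```python
def doubly_stochastic_penalty(alphas, alpha_c=1.0):
    """alphas[b][t][p]: attention weight on region p at decoding step t.

    For each region, sum its attention over all timesteps, then penalize
    the squared distance of that sum from 1, averaged over samples and
    regions (the averaging axes are an assumption).
    """
    squared_gaps = []
    for sample in alphas:
        num_regions = len(sample[0])
        for p in range(num_regions):
            total = sum(step[p] for step in sample)
            squared_gaps.append((1.0 - total) ** 2)
    return alpha_c * sum(squared_gaps) / len(squared_gaps)


# One sample, two timesteps, two regions.
balanced = doubly_stochastic_penalty([[[0.5, 0.5], [0.5, 0.5]]])  # sums: 1, 1
peaked = doubly_stochastic_penalty([[[1.0, 0.0], [1.0, 0.0]]])    # sums: 2, 0
```

The balanced case incurs no penalty, while attending to the same region at every step is penalized, which pushes the decoder to look at every part of the image over the course of a caption.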
## Files in this repo

- `attention_lstm.pth`: PyTorch checkpoint (encoder + decoder state
  dicts, optimizer state, training config)
- `vocab.pkl`: pickled `Vocabulary` object built from the train split
- `config.json`: JSON copy of the training hyperparameters
- `metrics_beam5.json`, `metrics_greedy.json`: full test-set metrics

## Usage

Loading the checkpoint and unpickling the vocabulary require the
`Vocabulary`, encoder, and decoder classes, so clone the source repo first to
make them importable:

```bash
git clone https://github.com/OmarGamal488/flickr8k-image-captioning.git
cd flickr8k-image-captioning
uv sync
```

Then in Python:

```python
import pickle

import torch
from huggingface_hub import hf_hub_download
from src.inference import load_attention_model, caption_image
from src.utils import get_device

# Download the checkpoint and vocabulary from the Hugging Face Hub.
repo_id = "OmarGamal48812/flickr8k-attention-lstm"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_lstm.pth")
vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

device = get_device()
# Unpickling needs the repo's Vocabulary class to be importable.
with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)

# Beam search with the same width used for the reported test metrics.
caption, beams = caption_image(
    encoder, decoder, "your_image.jpg", vocab, device,
    method="beam", beam_width=5,
)
print(caption)
```

For interactive use, the same repo ships a Gradio demo (`app.py`) and a
FastAPI service (`api/main.py`).

## Limitations

- **Small training set.** Flickr8k has only 6,000 training images, so the
  model often falls back to "safe" generic captions (e.g. *a dog runs through
  the grass*) for unfamiliar scenes.
- **Vocabulary cap.** Words seen fewer than 5 times in the train split
  collapse to `<unk>`, so rare nouns and proper names are systematically lost.
- **Domain.** Trained exclusively on Flickr8k photos (mostly people, dogs,
  outdoor scenes). Performance degrades on cartoons, screenshots, abstract
  imagery, and any scene type not represented in Flickr8k.
- **Hallucinations.** Like other autoregressive captioners, the decoder can
  insert objects that aren't in the image when attention drifts.
- **English only.** Vocabulary and grammar come entirely from the English
  Flickr8k captions.

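The vocabulary cap is easy to see in a small example. The sketch below builds a vocabulary with the frequency threshold and special-token ids given in the card, then encodes a caption containing a rare word. The lowercase whitespace tokenizer and the function names are assumptions made for illustration; the repo's `Vocabulary` class may differ.

```python
from collections import Counter

# Special tokens with the ids listed in the card: 0, 1, 2, 3.
SPECIALS = ["<pad>", "<start>", "<end>", "<unk>"]


def build_vocab(captions, min_freq=5):
    """Keep words seen at least min_freq times in the training captions."""
    counts = Counter(word for caption in captions
                     for word in caption.lower().split())
    kept = sorted(word for word, n in counts.items() if n >= min_freq)
    return {token: idx for idx, token in enumerate(SPECIALS + kept)}


def encode(caption, vocab):
    """Map a caption to ids, replacing out-of-vocabulary words with <unk>."""
    unk = vocab["<unk>"]
    ids = [vocab.get(word, unk) for word in caption.lower().split()]
    return [vocab["<start>"]] + ids + [vocab["<end>"]]


# "corgi" appears once, below the threshold, so it is dropped.
train_captions = ["a dog runs through the grass"] * 5 + ["a corgi runs"]
vocab = build_vocab(train_captions)
encoded = encode("a corgi runs", vocab)
```

Here `corgi` is encoded as id 3 (`<unk>`), so the decoder never learns to produce it, which is the mechanism behind the loss of rare nouns described above.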
## Intended use

Educational demonstrations of the *Show, Attend and Tell* architecture and
research baselines. Not appropriate as the sole captioning model behind
accessibility tooling; alt-text generation should use a model trained on a
much larger dataset.

## Citation

If you use this checkpoint, please credit the underlying paper:

```bibtex
@inproceedings{xu2015show,
  title     = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
  author    = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and
               Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
  booktitle = {ICML},
  year      = {2015}
}
```

and the dataset:

```bibtex
@article{hodosh2013framing,
  title   = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics},
  author  = {Hodosh, Micah and Young, Peter and Hockenmaier, Julia},
  journal = {Journal of Artificial Intelligence Research},
  year    = {2013}
}
```
attention_lstm.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91b41ebecea26453f6ce9ddab2702fb2db9f41dc17423a00a3d5d3ea9bfc8934
size 220277848
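The three lines committed for `attention_lstm.pth` are a Git LFS pointer rather than the weights themselves; the actual file is fetched by content hash. As a small illustration, the sketch below parses the pointer text shown in this diff and converts the recorded size to MiB. The function name is hypothetical.

```python
def parse_lfs_pointer(text):
    """Parse the `key value` lines of a Git LFS pointer file into a dict."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    hash_algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:91b41ebecea26453f6ce9ddab2702fb2db9f41dc17423a00a3d5d3ea9bfc8934\n"
    "size 220277848\n"
)
size_mib = pointer["size_bytes"] / 2**20  # roughly 210 MiB
```

The size covers everything stored in the checkpoint, which according to the model card includes the optimizer state as well as the encoder and decoder weights.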
config.json
ADDED
@@ -0,0 +1,34 @@
{
  "images_dir": "data/raw/Images",
  "processed_dir": "data/processed",

  "encoder_dim": 2048,
  "embed_size": 256,
  "hidden_size": 512,
  "attention_dim": 256,
  "dropout": 0.5,
  "rnn_type": "lstm",

  "alpha_c": 1.0,

  "batch_size": 32,
  "num_workers": 4,
  "num_epochs": 25,
  "decoder_lr": 4e-4,
  "encoder_lr": 1e-5,
  "weight_decay": 0.0,
  "grad_clip": 5.0,
  "scheduler_patience": 3,
  "scheduler_factor": 0.5,

  "fine_tune_start_epoch": 16,
  "fine_tune_blocks": 2,

  "seed": 42,
  "save_dir": "models",
  "run_name": "attention_lstm",
  "log_interval": 50,
  "val_bleu_subset": 200,
  "wandb_project": "flickr8k-captioning",
  "wandb_mode": "online"
}
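The two training phases described in the model card can be derived from `num_epochs` and `fine_tune_start_epoch`. The sketch below assumes epochs are 1-indexed and that the encoder learning rate only applies once fine-tuning starts; `phase_for_epoch` is an illustrative helper, not a function from the repo, and it reports base learning rates before any scheduler reductions.

```python
import json

# Subset of config.json relevant to the training schedule.
config = json.loads("""
{
  "num_epochs": 25,
  "decoder_lr": 4e-4,
  "encoder_lr": 1e-5,
  "fine_tune_start_epoch": 16,
  "fine_tune_blocks": 2
}
""")


def phase_for_epoch(epoch, cfg):
    """Return the phase and base learning rates for a 1-indexed epoch."""
    if epoch < cfg["fine_tune_start_epoch"]:
        # Phase A: encoder frozen, only the decoder is optimized.
        return {"phase": "A", "lrs": {"decoder": cfg["decoder_lr"]}}
    # Phase B: the last `fine_tune_blocks` ResNet blocks are also trained.
    return {
        "phase": "B",
        "lrs": {"decoder": cfg["decoder_lr"], "encoder": cfg["encoder_lr"]},
    }


phases = [phase_for_epoch(e, config)["phase"]
          for e in range(1, config["num_epochs"] + 1)]
phase_a_epochs = phases.count("A")
phase_b_epochs = phases.count("B")
```

Under these assumptions the config yields 15 Phase A epochs and 10 Phase B epochs, matching the split stated in the README.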
metrics_beam5.json
ADDED
@@ -0,0 +1,17 @@
{
  "checkpoint": "models/attention_lstm.pth",
  "split": "test",
  "n_images": 1091,
  "method": "beam",
  "beam_width": 5,
  "max_len": 20,
  "rnn_type": "lstm",
  "BLEU-1": 0.6488,
  "BLEU-2": 0.4714,
  "BLEU-3": 0.3378,
  "BLEU-4": 0.2403,
  "METEOR": 0.427,
  "CIDEr": 0.6002,
  "ROUGE-L": 0.4788,
  "wall_clock_s": 21.1
}
metrics_greedy.json
ADDED
@@ -0,0 +1,17 @@
{
  "checkpoint": "models/attention_lstm.pth",
  "split": "test",
  "n_images": 1091,
  "method": "greedy",
  "beam_width": null,
  "max_len": 20,
  "rnn_type": "lstm",
  "BLEU-1": 0.6342,
  "BLEU-2": 0.4485,
  "BLEU-3": 0.3057,
  "BLEU-4": 0.2073,
  "METEOR": 0.4119,
  "CIDEr": 0.5322,
  "ROUGE-L": 0.4654,
  "wall_clock_s": 6.1
}
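The two metrics files make the cost of beam search easy to quantify. The sketch below copies the relevant values from `metrics_beam5.json` and `metrics_greedy.json` and computes the relative gain per metric and the wall-clock ratio. It assumes `wall_clock_s` is the total decoding time in seconds for the 1,091 test images, which the files do not state explicitly.

```python
# Values copied from metrics_beam5.json and metrics_greedy.json.
beam = {"BLEU-4": 0.2403, "METEOR": 0.427, "CIDEr": 0.6002,
        "ROUGE-L": 0.4788, "wall_clock_s": 21.1}
greedy = {"BLEU-4": 0.2073, "METEOR": 0.4119, "CIDEr": 0.5322,
          "ROUGE-L": 0.4654, "wall_clock_s": 6.1}


def relative_gain(new, old):
    """Fractional improvement of `new` over `old`."""
    return (new - old) / old


quality_metrics = ("BLEU-4", "METEOR", "CIDEr", "ROUGE-L")
gains = {name: relative_gain(beam[name], greedy[name])
         for name in quality_metrics}
slowdown = beam["wall_clock_s"] / greedy["wall_clock_s"]

for name in quality_metrics:
    print(f"{name}: {gains[name]:+.1%}")
print(f"beam search is {slowdown:.2f}x slower than greedy")
```

On these numbers beam search improves BLEU-4 by about 16% and CIDEr by about 13% relative to greedy decoding, at roughly 3.5 times the decoding time.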
vocab.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30adbb8a77440e549df89caf254b77ebb0c269fdc6bec5ee3a3a79f310521c07
size 126102