OmarGamal48812 commited on
Commit
3335cda
·
verified ·
1 Parent(s): 3b1e619

Upload Flickr8k attention-LSTM checkpoint + model card

Browse files
Files changed (6) hide show
  1. README.md +171 -0
  2. attention_lstm.pth +3 -0
  3. config.json +34 -0
  4. metrics_beam5.json +17 -0
  5. metrics_greedy.json +17 -0
  6. vocab.pkl +3 -0
README.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: mit
4
+ tags:
5
+ - image-captioning
6
+ - pytorch
7
+ - resnet
8
+ - attention
9
+ - lstm
10
+ - flickr8k
11
+ - show-attend-and-tell
12
+ datasets:
13
+ - nlphuji/flickr8k
14
+ metrics:
15
+ - bleu
16
+ - meteor
17
+ - cider
18
+ - rouge
19
+ library_name: pytorch
20
+ pipeline_tag: image-to-text
21
+ ---
22
+
23
+ # Flickr8k Image Captioning — ResNet50 + Bahdanau Attention + LSTM Decoder
24
+
25
+ This model generates a natural-language description of an image. It uses a
26
+ **ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
27
+ module, and an **LSTM decoder**, trained with teacher forcing and doubly
28
+ stochastic regularization on the **Flickr8k** dataset (8,091 images × 5
29
+ captions). It is the reference architecture from
30
+ [*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044).
31
+
32
+ ## Test-set performance (beam search, k = 5)
33
+
34
+ | Metric | Value |
35
+ |---|---|
36
+ | BLEU-1 | 0.6488 |
37
+ | BLEU-2 | 0.4714 |
38
+ | BLEU-3 | 0.3378 |
39
+ | **BLEU-4** | **0.2403** |
40
+ | METEOR | 0.4270 |
41
+ | CIDEr | 0.6002 |
42
+ | ROUGE-L | 0.4788 |
43
+
44
+ Greedy decoding scores: BLEU-4 = 0.2073, METEOR = 0.4119, CIDEr = 0.5322.
45
+
46
+ Evaluated on the held-out 1,091-image test split (image-level split — no
47
+ captions cross train/val/test). Beam search uses length-normalized log-probs
48
+ (`alpha = 0.7`) and a repetition penalty of `1.2`.
49
+
50
+ ## Architecture
51
+
52
+ ```
53
+ Image (3, 224, 224)
54
+ └─ ResNet50 (pretrained, frozen first 15 epochs, last 2 blocks fine-tuned)
55
+ output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
56
+ └─ Bahdanau attention V·tanh(W_enc(features) + W_dec(h_prev))
57
+ output: context vector (B, 2048), attention weights (B, 49)
58
+ └─ LSTMCell (per timestep — re-queries attention each step)
59
+ hidden state size: 512, embedding size: 256
60
+ └─ Linear → vocab logits (V = 2,557)
61
+ ```
62
+
63
+ Total parameters: **~36 M** (28 M frozen ResNet, 8 M trainable decoder/projection).
64
+
65
+ ## Training details
66
+
67
+ - **Loss** — `CrossEntropyLoss(ignore_index=0)` plus doubly-stochastic
68
+ regularization `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0`
69
+ - **Optimizer** — Adam, decoder LR `4e-4`, encoder LR `1e-5` (Phase B)
70
+ - **Schedule** — `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`,
71
+ `patience=3`
72
+ - **Two-phase training** — Phase A (15 epochs): freeze CNN, train decoder
73
+ only. Phase B (10 epochs): unfreeze last 2 ResNet blocks.
74
+ - **Vocabulary** — 2,557 tokens (frequency threshold 5), built from train
75
+ captions only. Special tokens: `<pad>=0, <start>=1, <end>=2, <unk>=3`.
76
+ - **Batch size** — 32, gradient clip 5.0
77
+ - **Seed** — 42
78
+
79
+ ## Files in this repo
80
+
81
+ - `attention_lstm.pth` — PyTorch checkpoint (encoder + decoder state
82
+ dicts, optimizer state, training config)
83
+ - `vocab.pkl` — pickled `Vocabulary` object built from the train split
84
+ - `config.json` — JSON copy of the training hyperparameters
85
+ - `metrics_beam5.json`, `metrics_greedy.json` — full test-set metrics
86
+
87
+ ## Usage
88
+
89
+ The cleanest way to use this model is to clone the source repo so the
90
+ `Vocabulary`, encoder, and decoder classes are importable:
91
+
92
+ ```bash
93
+ git clone https://github.com/OmarGamal488/flickr8k-image-captioning.git
94
+ cd flickr8k-image-captioning
95
+ uv sync
96
+ ```
97
+
98
+ Then in Python:
99
+
100
+ ```python
101
+ import pickle, torch
102
+ from huggingface_hub import hf_hub_download
103
+ from src.inference import load_attention_model, caption_image
104
+ from src.utils import get_device
105
+
106
+ repo_id = "OmarGamal48812/flickr8k-attention-lstm"
107
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_lstm.pth")
108
+ vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")
109
+
110
+ device = get_device()
111
+ with open(vocab_path, "rb") as f:
112
+ vocab = pickle.load(f)
113
+
114
+ encoder, decoder, cfg = load_attention_model(ckpt_path, len(vocab), device)
115
+
116
+ caption, beams = caption_image(
117
+ encoder, decoder, "your_image.jpg", vocab, device,
118
+ method="beam", beam_width=5,
119
+ )
120
+ print(caption)
121
+ ```
122
+
123
+ For interactive use, the same repo ships a Gradio demo (`app.py`) and a
124
+ FastAPI service (`api/main.py`).
125
+
126
+ ## Limitations
127
+
128
+ - **Small training set.** Flickr8k has only 6,000 training images, so the
129
+ model often falls back to "safe" generic captions (e.g. *a dog runs through
130
+ the grass*) for unfamiliar scenes.
131
+ - **Vocabulary cap.** Words seen fewer than 5 times in the train split
132
+ collapse to `<unk>`. Rare nouns and proper names are systematically lost.
133
+ - **Domain.** Trained exclusively on Flickr8k photos (mostly people, dogs,
134
+ outdoor scenes). Performance degrades on cartoons, screenshots, abstract
135
+ imagery, and any scene type not represented in Flickr8k.
136
+ - **Hallucinations.** Like all autoregressive captioners, the decoder can
137
+ insert objects that aren't in the image when attention drifts.
138
+ - **English only.** Vocabulary and grammar are entirely English Flickr8k
139
+ captions.
140
+
141
+ ## Intended use
142
+
143
+ Educational demonstrations of the Show-Attend-Tell architecture and
144
+ research baselines. Not appropriate as the only data source for
145
+ accessibility tooling (alt-text generation should ideally use a model
146
+ trained on a much larger dataset).
147
+
148
+ ## Citation
149
+
150
+ If you use this checkpoint, please credit the underlying paper:
151
+
152
+ ```bibtex
153
+ @inproceedings{xu2015show,
154
+ title = {Show, Attend and Tell: Neural Image Caption Generation with Visual Attention},
155
+ author = {Xu, Kelvin and Ba, Jimmy and Kiros, Ryan and Cho, Kyunghyun and Courville, Aaron and
156
+ Salakhutdinov, Ruslan and Zemel, Richard and Bengio, Yoshua},
157
+ booktitle = {ICML},
158
+ year = {2015}
159
+ }
160
+ ```
161
+
162
+ and the dataset:
163
+
164
+ ```bibtex
165
+ @article{hodosh2013framing,
166
+ title = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics},
167
+ author = {Hodosh, Micah and Young, Peter and Hockenmaier, Julia},
168
+ journal = {Journal of Artificial Intelligence Research},
169
+ year = {2013}
170
+ }
171
+ ```
attention_lstm.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91b41ebecea26453f6ce9ddab2702fb2db9f41dc17423a00a3d5d3ea9bfc8934
3
+ size 220277848
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "images_dir": "data/raw/Images",
3
+ "processed_dir": "data/processed",
4
+
5
+ "encoder_dim": 2048,
6
+ "embed_size": 256,
7
+ "hidden_size": 512,
8
+ "attention_dim": 256,
9
+ "dropout": 0.5,
10
+ "rnn_type": "lstm",
11
+
12
+ "alpha_c": 1.0,
13
+
14
+ "batch_size": 32,
15
+ "num_workers": 4,
16
+ "num_epochs": 25,
17
+ "decoder_lr": 4e-4,
18
+ "encoder_lr": 1e-5,
19
+ "weight_decay": 0.0,
20
+ "grad_clip": 5.0,
21
+ "scheduler_patience": 3,
22
+ "scheduler_factor": 0.5,
23
+
24
+ "fine_tune_start_epoch": 16,
25
+ "fine_tune_blocks": 2,
26
+
27
+ "seed": 42,
28
+ "save_dir": "models",
29
+ "run_name": "attention_lstm",
30
+ "log_interval": 50,
31
+ "val_bleu_subset": 200,
32
+ "wandb_project": "flickr8k-captioning",
33
+ "wandb_mode": "online"
34
+ }
metrics_beam5.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint": "models/attention_lstm.pth",
3
+ "split": "test",
4
+ "n_images": 1091,
5
+ "method": "beam",
6
+ "beam_width": 5,
7
+ "max_len": 20,
8
+ "rnn_type": "lstm",
9
+ "BLEU-1": 0.6488,
10
+ "BLEU-2": 0.4714,
11
+ "BLEU-3": 0.3378,
12
+ "BLEU-4": 0.2403,
13
+ "METEOR": 0.427,
14
+ "CIDEr": 0.6002,
15
+ "ROUGE-L": 0.4788,
16
+ "wall_clock_s": 21.1
17
+ }
metrics_greedy.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint": "models/attention_lstm.pth",
3
+ "split": "test",
4
+ "n_images": 1091,
5
+ "method": "greedy",
6
+ "beam_width": null,
7
+ "max_len": 20,
8
+ "rnn_type": "lstm",
9
+ "BLEU-1": 0.6342,
10
+ "BLEU-2": 0.4485,
11
+ "BLEU-3": 0.3057,
12
+ "BLEU-4": 0.2073,
13
+ "METEOR": 0.4119,
14
+ "CIDEr": 0.5322,
15
+ "ROUGE-L": 0.4654,
16
+ "wall_clock_s": 6.1
17
+ }
vocab.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30adbb8a77440e549df89caf254b77ebb0c269fdc6bec5ee3a3a79f310521c07
3
+ size 126102