OmarGamal48812 committed on
Commit 9d398b7 · verified · 1 Parent(s): 70bd798

Upload README.md with huggingface_hub

  1. README.md +65 -74
README.md CHANGED
@@ -6,8 +6,10 @@ tags:
  - pytorch
  - resnet
  - attention
- - lstm
  - flickr8k
  - show-attend-and-tell
  datasets:
  - nlphuji/flickr8k
@@ -20,78 +22,74 @@ library_name: pytorch
  pipeline_tag: image-to-text
  ---

- # Flickr8k Image Captioning: ResNet50 + Bahdanau Attention + LSTM Decoder

  This model generates a natural-language description of an image. It uses a
  **ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
- module, and an **LSTM decoder**, trained with teacher forcing and doubly
- stochastic regularization on the **Flickr8k** dataset (8,091 images × 5
- captions). It is the reference architecture from
- [*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044).

  ## Test-set performance (beam search, k = 5)

  | Metric | Value |
  |---|---|
- | BLEU-1 | 0.6488 |
- | BLEU-2 | 0.4714 |
- | BLEU-3 | 0.3378 |
- | **BLEU-4** | **0.2403** |
- | METEOR | 0.4270 |
- | CIDEr | 0.6002 |
- | ROUGE-L | 0.4788 |
-
- Greedy decoding scores: BLEU-4 = 0.2073, METEOR = 0.4119, CIDEr = 0.5322.

- Evaluated on the held-out 1,091-image test split (image-level split; no
- captions cross train/val/test). Beam search uses length-normalized log-probs
- (`alpha = 0.7`) and a repetition penalty of `1.2`.

  ## Architecture

  ```
  Image (3, 224, 224)
- └─ ResNet50 (pretrained, frozen first 15 epochs, last 2 blocks fine-tuned)
  output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
  └─ Bahdanau attention V·tanh(W_enc(features) + W_dec(h_prev))
  output: context vector (B, 2048), attention weights (B, 49)
- └─ LSTMCell (per timestep; re-queries attention each step)
- hidden state size: 512, embedding size: 256
- └─ Linear → vocab logits (V = 2,557)
  ```

- Total parameters: **~36 M** (28 M frozen ResNet, 8 M trainable decoder/projection).

  ## Training details

- - **Loss**: `CrossEntropyLoss(ignore_index=0)` plus doubly-stochastic
- regularization `α_c · ((1 − Σ_t α_t)²).mean()` with `α_c = 1.0`
- - **Optimizer**: Adam, decoder LR `4e-4`, encoder LR `1e-5` (Phase B)
- - **Schedule**: `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`,
- `patience=3`
- - **Two-phase training**: Phase A (15 epochs): freeze CNN, train decoder
- only. Phase B (10 epochs): unfreeze last 2 ResNet blocks.
- - **Vocabulary**: 2,557 tokens (frequency threshold 5), built from train
- captions only. Special tokens: `<pad>=0, <start>=1, <end>=2, <unk>=3`.
- - **Batch size**: 32, gradient clip 5.0
- - **Seed**: 42

  ## Files in this repo

- - `attention_lstm.pth`: PyTorch checkpoint (encoder + decoder state
- dicts, optimizer state, training config)
  - `vocab.pkl`: pickled `Vocabulary` object built from the train split
  - `config.json`: JSON copy of the training hyperparameters
- - `metrics_beam5.json`, `metrics_greedy.json`: full test-set metrics

  ## Usage

- The cleanest way to use this model is to clone the source repo so the
- `Vocabulary`, encoder, and decoder classes are importable:
-
  ```bash
- git clone https://github.com/OmarGamal488/flickr8k-image-captioning.git
- cd flickr8k-image-captioning
  uv sync
  ```

@@ -104,7 +102,7 @@ from src.inference import load_attention_model, caption_image
  from src.utils import get_device

  repo_id = "OmarGamal48812/flickr8k-attention-lstm"
- ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_lstm.pth")
  vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

  device = get_device()
@@ -118,36 +116,25 @@ caption, beams = caption_image(
  method="beam", beam_width=5,
  )
  print(caption)
  ```

- For interactive use, the same repo ships a Gradio demo (`app.py`) and a
- FastAPI service (`api/main.py`).
-
  ## Limitations

- - **Small training set.** Flickr8k has only 6,000 training images, so the
- model often falls back to "safe" generic captions (e.g. *a dog runs through
- the grass*) for unfamiliar scenes.
- - **Vocabulary cap.** Words seen fewer than 5 times in the train split
- collapse to `<unk>`. Rare nouns and proper names are systematically lost.
- - **Domain.** Trained exclusively on Flickr8k photos (mostly people, dogs,
- outdoor scenes). Performance degrades on cartoons, screenshots, abstract
- imagery, and any scene type not represented in Flickr8k.
- - **Hallucinations.** Like all autoregressive captioners, the decoder can
- insert objects that aren't in the image when attention drifts.
- - **English only.** Vocabulary and grammar are entirely English Flickr8k
- captions.
-
- ## Intended use
-
- Educational demonstrations of the Show-Attend-Tell architecture and
- research baselines. Not appropriate as the only data source for
- accessibility tooling (alt-text generation should ideally use a model
- trained on a much larger dataset).

  ## Citation

- If you use this checkpoint, please credit the underlying paper:

  ```bibtex
  @inproceedings{xu2015show,
@@ -157,15 +144,19 @@ If you use this checkpoint, please credit the underlying paper:
  booktitle = {ICML},
  year = {2015}
  }
- ```

- and the dataset:

- ```bibtex
- @article{hodosh2013framing,
- title = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics},
- author = {Hodosh, Micah and Young, Peter and Hockenmaier, Julia},
- journal = {Journal of Artificial Intelligence Research},
- year = {2013}
  }
  ```
@@ -6,8 +6,10 @@ tags:
  - pytorch
  - resnet
  - attention
+ - gru
+ - glove
  - flickr8k
+ - flickr30k
  - show-attend-and-tell
  datasets:
  - nlphuji/flickr8k
@@ -20,78 +22,74 @@ library_name: pytorch
  pipeline_tag: image-to-text
  ---

+ # Flickr Image Captioning: ResNet50 + Bahdanau Attention + GRU + GloVe

  This model generates a natural-language description of an image. It uses a
  **ResNet50** spatial-feature encoder, a **Bahdanau (additive)** attention
+ module, and a **GRU decoder** initialized with **GloVe 6B 300d** embeddings,
+ trained on the merged **Flickr8k + Flickr30k** dataset (39,874 images × 5
+ captions). It follows the architecture from
+ [*Show, Attend and Tell* (Xu et al., 2015)](https://arxiv.org/abs/1502.03044)
+ with label smoothing, scheduled sampling, and two-phase CNN fine-tuning.

  ## Test-set performance (beam search, k = 5)

+ Evaluated on the held-out 1,873-image test split (image-level split; no
+ captions cross train/val/test).
+
  | Metric | Value |
  |---|---|
+ | BLEU-1 | 0.6859 |
+ | BLEU-2 | 0.5289 |
+ | BLEU-3 | 0.4041 |
+ | **BLEU-4** | **0.3093** |
+ | METEOR | 0.4709 |
+ | CIDEr | 0.7961 |
+ | ROUGE-L | 0.5257 |

+ Beam search uses length-normalized log-probs (`alpha = 0.7`) and a
+ repetition penalty of `1.2`.
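
The scoring code itself lives in the source repo and is not reproduced in this README. As a rough sketch of what length normalization and a repetition penalty usually look like, assuming the common `sum(log p) / len**alpha` form and a multiplicative penalty on the log-probs of tokens already emitted (both function names are hypothetical and may not match the repo):

```python
def length_normalized_score(token_logprobs, alpha=0.7):
    """Score a hypothesis as sum(log p) / len**alpha (assumed form)."""
    if not token_logprobs:
        raise ValueError("hypothesis must contain at least one token")
    return sum(token_logprobs) / (len(token_logprobs) ** alpha)


def apply_repetition_penalty(logprobs, generated_ids, penalty=1.2):
    """Lower the log-prob of every token id that was already generated.

    Log-probs are <= 0, so multiplying by penalty > 1 pushes them further down.
    """
    penalized = list(logprobs)
    for idx in set(generated_ids):
        penalized[idx] *= penalty
    return penalized


# Without normalization the shorter hypothesis wins on raw log-prob sum;
# with alpha = 0.7 the longer one, whose tokens are individually more likely, wins.
short_hyp = [-0.5, -0.5]
long_hyp = [-0.4, -0.4, -0.4, -0.4]
print(sum(short_hyp) > sum(long_hyp))
print(length_normalized_score(long_hyp) > length_normalized_score(short_hyp))
```

Setting `alpha = 0` recovers the raw sum and `alpha = 1` the per-token average; `0.7` sits between the two.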
 
  ## Architecture

  ```
  Image (3, 224, 224)
+ └─ ResNet50 (pretrained, frozen first 10 epochs, last 2 blocks fine-tuned)
  output: (B, 2048, 7, 7) → reshape to (B, 49, 2048)
  └─ Bahdanau attention V·tanh(W_enc(features) + W_dec(h_prev))
  output: context vector (B, 2048), attention weights (B, 49)
+ └─ GRUCell (per timestep; re-queries attention each step)
+ hidden state size: 1024, embedding size: 300 (GloVe 6B 300d)
+ └─ Linear → vocab logits (V = 10,111)
  ```

+ Total parameters: **~37 M** (25 M frozen ResNet, 12 M trainable decoder/projection).
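
In equation form, the additive attention step in the diagram can be written as below. The symbols are illustrative rather than taken from the repo: $a_i$ is the $i$-th of the 49 encoder feature vectors, $h_{t-1}$ is the previous decoder hidden state, bias terms are omitted, and the softmax normalization is the standard choice for Bahdanau attention and is assumed here.

```latex
\begin{aligned}
e_{t,i} &= v^{\top} \tanh\!\left(W_{\text{enc}}\, a_i + W_{\text{dec}}\, h_{t-1}\right), \quad i = 1, \dots, 49 \\
\alpha_{t,i} &= \frac{\exp(e_{t,i})}{\sum_{j=1}^{49} \exp(e_{t,j})} \\
z_t &= \sum_{i=1}^{49} \alpha_{t,i}\, a_i \;\in\; \mathbb{R}^{2048}
\end{aligned}
```

The weights $\alpha_{t,\cdot}$ correspond to the `(B, 49)` attention output and $z_t$ to the `(B, 2048)` context vector in the diagram.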
67
 
68
  ## Training details
69
 
70
+ - **Dataset** β€” Flickr8k + Flickr30k merged (37,000 train / 1,000 val / 1,873 test)
71
+ - **Vocabulary** β€” 10,111 tokens (frequency threshold 3), built from train
72
+ captions only. Special tokens: `<pad>=0, <start>=1, <end>=2, `<unk>=3`.
73
+ - **Loss** β€” `CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)` plus
74
+ doubly-stochastic regularization `Ξ±_c Β· ((1 βˆ’ Ξ£_t Ξ±_t)Β²).mean()` with `Ξ±_c = 1.0`
75
+ - **Optimizer** β€” Adam, decoder LR `3.2e-3`, encoder LR `8e-5` (Phase B)
76
+ - **Schedule** β€” `ReduceLROnPlateau` on val BLEU-4, `factor=0.5`, `patience=3`
77
+ - **Two-phase training** β€” Phase A (epochs 1–10): freeze CNN. Phase B (epochs 11–35): unfreeze last 2 ResNet blocks.
78
+ - **Scheduled sampling** β€” linear ramp from 0 to max 0.25 over training epochs
79
+ - **Batch size** β€” 256, gradient clip 5.0, seed 42
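
Two of the items above, the doubly-stochastic regularizer and the scheduled-sampling ramp, are easy to misread from a one-line description. The sketch below is pure Python for a single caption, not the repo's PyTorch code: the function names are hypothetical, the mean is taken over image locations only (the training code presumably also averages over the batch), and the exact shape of the ramp is an assumption.

```python
def doubly_stochastic_penalty(alphas, alpha_c=1.0):
    """alpha_c * mean_i (1 - sum_t alpha[t][i])**2 for one caption.

    alphas: list of T attention rows, each a list of L weights summing to 1.
    """
    num_locations = len(alphas[0])
    # Total attention each image location received across all decoding steps.
    totals = [sum(step[i] for step in alphas) for i in range(num_locations)]
    squared_gaps = [(1.0 - total) ** 2 for total in totals]
    return alpha_c * sum(squared_gaps) / num_locations


def scheduled_sampling_prob(epoch, total_epochs, max_prob=0.25):
    """Linear ramp from 0 at epoch 1 to max_prob at the final epoch (assumed shape)."""
    if total_epochs <= 1:
        return max_prob
    frac = (epoch - 1) / (total_epochs - 1)
    return max_prob * min(max(frac, 0.0), 1.0)


# Attention spread evenly over 4 locations for 4 steps incurs no penalty;
# attention stuck on one location is penalized.
uniform = [[0.25, 0.25, 0.25, 0.25] for _ in range(4)]
peaked = [[1.0, 0.0, 0.0, 0.0] for _ in range(4)]
print(doubly_stochastic_penalty(uniform), doubly_stochastic_penalty(peaked))
print([scheduled_sampling_prob(e, 35) for e in (1, 18, 35)])
```

The penalty encourages every image location to receive roughly one unit of attention over the whole caption, which is the motivation given in Xu et al. (2015).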
 

  ## Files in this repo

+ - `attention_gru_glove.pth`: PyTorch checkpoint (encoder + decoder state dicts, config)
  - `vocab.pkl`: pickled `Vocabulary` object built from the train split
  - `config.json`: JSON copy of the training hyperparameters
+ - `metrics_beam5.json`: full test-set metrics (beam search k=5)

  ## Usage

  ```bash
+ git clone https://github.com/OmarGamal488/flickr-image-captioning.git
+ cd flickr-image-captioning
  uv sync
  ```

 
@@ -104,7 +102,7 @@ from src.inference import load_attention_model, caption_image
  from src.utils import get_device

  repo_id = "OmarGamal48812/flickr8k-attention-lstm"
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename="attention_gru_glove.pth")
  vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.pkl")

  device = get_device()
@@ -118,36 +116,25 @@ caption, beams = caption_image(
  method="beam", beam_width=5,
  )
  print(caption)
+ for b in beams[:3]:
+     print(f" {b.score:+.3f} {b.caption}")
  ```

  ## Limitations

+ - **Domain.** Trained on Flickr8k + Flickr30k photos (mostly people, dogs,
+ outdoor scenes). Performance degrades on cartoons, screenshots, and abstract imagery.
+ - **Safe-word bias.** Only 8.8% of the 10,111-word vocabulary is used at inference;
+ the decoder converges on template phrases like *"a man in a white shirt is standing"*.
+ - **No object counting.** The attention context vector collapses object count:
+ the model often says "a dog" when the image shows two dogs.
+ - **Hallucinations.** The decoder can insert objects not in the image when visual
+ evidence is weak and the language-model prior takes over.
+ - **English only.** Vocabulary and grammar are entirely from English Flickr captions.
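
The safe-word figure above (8.8% of 10,111 is roughly 890 distinct words) can be checked on any set of generated captions with a few lines of Python. This is a hedged sketch, not the script used for the reported number: `vocab_usage` is a hypothetical helper, it assumes lowercase whitespace tokenization, and it assumes the special tokens are excluded from the count of used words.

```python
def vocab_usage(captions, vocab_size,
                special_tokens=("<pad>", "<start>", "<end>", "<unk>")):
    """Return (distinct words used, fraction of the vocabulary they cover)."""
    used = set()
    for caption in captions:
        used.update(tok for tok in caption.lower().split()
                    if tok not in special_tokens)
    return len(used), len(used) / vocab_size


generated = [
    "a man in a white shirt is standing",
    "a dog runs through the grass",
]
count, fraction = vocab_usage(generated, vocab_size=10_111)
print(f"{count} distinct words, {fraction:.2%} of the vocabulary")
```

Running it over the full set of test-set captions, rather than two examples, gives the statistic the bullet describes.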

  ## Citation

+ If you use this checkpoint, please cite the three papers this work builds on:

  ```bibtex
  @inproceedings{xu2015show,
@@ -157,15 +144,19 @@ If you use this checkpoint, please credit the underlying paper:
  booktitle = {ICML},
  year = {2015}
  }

+ @article{bahdanau2014neural,
+ title = {Neural Machine Translation by Jointly Learning to Align and Translate},
+ author = {Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
+ journal = {arXiv preprint arXiv:1409.0473},
+ year = {2014}
+ }

+ @inproceedings{selvaraju2017gradcam,
+ title = {Grad-{CAM}: Visual Explanations from Deep Networks via Gradient-based Localization},
+ author = {Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and
+ Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
+ booktitle = {ICCV},
+ year = {2017}
  }
  ```