pchandragrid commited on
Commit
a745a5e
·
0 Parent(s):

Deploy Streamlit app

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ .pytest_cache/
4
+ .mypy_cache/
5
+ .ruff_cache/
6
+ .DS_Store
7
+
8
+ # Local datasets
9
+ train2017/
10
+ val2017/
11
+ annotations/*.jsonl
12
+
13
+ # Local model artifacts (don't commit)
14
+ saved_model_phase1/
15
+ saved_model_phase2/
16
+ saved_vit_gpt2/
17
+ saved_git_model/
18
+ saved_model_phase2.bak/
19
+ saved_vit_gpt2.bak/
20
+ saved_git_model.bak/
21
+
22
+ # Hugging Face cache (optional)
23
+ .cache/
24
+ hf_cache/
25
+ hf_home/
26
+ # virtual environment
27
+ .venv/
28
+
29
+ # python cache
30
+ __pycache__/
31
+ *.pyc
32
+
33
+ # datasets
34
+ train2017/
35
+ val2017/
36
+ annotations/
37
+
38
+ # trained models
39
+ saved_model/
40
+
41
+ saved_model_20k_384/
42
+
43
+
44
+
45
+ # checkpoints
46
+ checkpoints/
47
+ checkpoints_20k_384/
48
+
49
+ # mac files
50
+ .DS_Store
51
+
52
+ # Hugging Face Spaces: avoid binaries in git
53
+ *.png
54
+ *.jpg
55
+ *.jpeg
56
+ *.gif
57
+ *.webp
.streamlit/config.toml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ headless = true
3
+ enableCORS = false
4
+ enableXsrfProtection = true
5
+
6
+ [browser]
7
+ gatherUsageStats = false
8
+
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.13.5-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ git \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY requirements.txt ./
12
+ COPY src/ ./src/
13
+
14
+ RUN pip3 install -r requirements.txt
15
+
16
+ EXPOSE 8501
17
+
18
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
+
20
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image Captioning (Streamlit)
2
+
3
+ This repo hosts a Streamlit app (`app.py`) that compares multiple image-captioning models.
4
+
5
+ ## Why your models should NOT be inside the app repo
6
+
7
+ Fine-tuned checkpoints are large. Public hosting (Hugging Face Spaces / Streamlit Cloud) works best when:
8
+
9
+ - the app repo stays small
10
+ - models live on the Hugging Face Hub (or S3/GCS)
11
+ - the app downloads models at startup (cached by `transformers`)
12
+
13
+ ## 1) Upload your saved models to Hugging Face Hub
14
+
15
+ Example for BLIP (you already have `uploadtohf.py`):
16
+
17
+ ```bash
18
+ pip install -U transformers huggingface_hub
19
+ huggingface-cli login
20
+ python uploadtohf.py
21
+ ```
22
+
23
+ Do the same for your other local folders (`saved_vit_gpt2`, `saved_git_model`) by pushing them to separate Hub repos.
24
+
25
+ ## 2) Configure the app to load from Hub
26
+
27
+ `app.py` loads **local folders if present**, otherwise falls back to Hub IDs via environment variables:
28
+
29
+ - `BLIP_MODEL_ID` (default: `prateekchandra/blip-caption-model`)
30
+ - `VITGPT2_MODEL_ID` (default: `prateekchandra/vit-gpt2-caption-model`)
31
+ - `GIT_MODEL_ID` (default: `prateekchandra/git-caption-model`)
32
+
33
+ In this repo, defaults are set to:
34
+
35
+ - `BLIP_MODEL_ID` (default: `pchandragrid/blip-caption-model`)
36
+ - `VITGPT2_MODEL_ID` (default: `pchandragrid/vit-gpt2-caption-model`)
37
+ - `GIT_MODEL_ID` (default: `pchandragrid/git-caption-model`)
38
+
39
+ You can also override local folder names:
40
+
41
+ - `BLIP_LOCAL_DIR` (default: `saved_model_phase2`)
42
+ - `VITGPT2_LOCAL_DIR` (default: `saved_vit_gpt2`)
43
+ - `GIT_LOCAL_DIR` (default: `saved_git_model`)
44
+
45
+ ## 3) Deploy options
46
+
47
+ ### Option A: Hugging Face Spaces (recommended)
48
+
49
+ - Create a new Space: **Streamlit**
50
+ - Push this repo (must include `app.py` + `requirements.txt`)
51
+ - In Space “Variables”, set `BLIP_MODEL_ID`, `VITGPT2_MODEL_ID`, `GIT_MODEL_ID` to your Hub repos
52
+ - If any model repo is private, add `HF_TOKEN` as a Space **Secret**
53
+
54
+ ### Option B: Streamlit Community Cloud
55
+
56
+ - Point it to this repo
57
+ - Set the same env vars in the app settings
58
+
59
+ ## Local run
60
+
61
+ ```bash
62
+ python -m venv .venv
63
+ source .venv/bin/activate
64
+ pip install -r requirements.txt
65
+ streamlit run app.py
66
+ ```
67
+
68
+ # 🖼️ Image Captioning with BLIP (COCO Subset)
69
+
70
+ ## 📌 Problem
71
+
72
+ Generate natural language descriptions for images using transformer-based vision-language models.
73
+
74
+ Goal:
75
+ - Improve CIDEr score by 10%+
76
+ - Compare architectures (BLIP vs ViT-GPT2)
77
+ - Analyze resolution impact (224 vs 320 vs 384)
78
+ - Optimize decoding parameters
79
+ - Deploy minimal inference UI
80
+
81
+ ---
82
+
83
+ ## 📂 Dataset
84
+
85
+ - MS COCO Captions (subset: 10k & 20k)
86
+ - Random caption selection (5 captions per image)
87
+ - Experiments:
88
+ - Short captions
89
+ - Mixed captions
90
+ - Filtered captions
91
+
92
+ Train/Validation split: 90/10
93
+
94
+ ---
95
+
96
+ ## 🧠 Models
97
+
98
+ ### 1️⃣ BLIP (Primary Model)
99
+ - Salesforce/blip-image-captioning-base
100
+ - Vision encoder frozen (for efficiency)
101
+ - Gradient checkpointing enabled
102
+ - Mixed precision on MPS
103
+
104
+ ### 2️⃣ ViT-GPT2 (Comparison)
105
+ - ViT base encoder
106
+ - GPT2 decoder with cross-attention
107
+
108
+ ---
109
+
110
+ ## 🧪 Experiments
111
+
112
+ ### Resolution Comparison
113
+ | Resolution | Dataset | CIDEr |
114
+ |------------|---------|--------|
115
+ | 224px | 10k | ~1.28 |
116
+ | 320px | 20k | ~1.33–1.38 |
117
+ | 384px | 20k | ~1.40+ |
118
+
119
+ ### Beam Search Tuning
120
+ Tested:
121
+ - Beams: 3, 5, 8
122
+ - Length penalty: 0.8, 1.0, 1.2
123
+ - Max length: 20, 30, 40
124
+
125
+ Best config:
126
+ Beams=5, MaxLen=20, LengthPenalty=1.0
127
+
128
+ ---
129
+
130
+ ## 📊 Evaluation Metric
131
+
132
+ - CIDEr (via pycocoevalcap)
133
+ - Validation loss
134
+ - Confidence estimation
135
+
136
+ ---
137
+
138
+ ## 🖥️ Demo
139
+
140
+ Streamlit app includes:
141
+ - Image uploader
142
+ - Beam controls
143
+ - Toxicity filtering
144
+ - Confidence display
145
+ - Attention heatmap
146
+
147
+ Run:
148
+ ```bash
149
+ streamlit run app.py
app.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import time
7
+ import pandas as pd
8
+
9
+ from transformers import (
10
+ BlipProcessor,
11
+ BlipForConditionalGeneration,
12
+ VisionEncoderDecoderModel,
13
+ ViTImageProcessor,
14
+ AutoTokenizer,
15
+ GitProcessor,
16
+ GitForCausalLM
17
+ )
18
+
19
+ from PIL import Image
20
+
21
+
22
+ def _get_device() -> torch.device:
23
+ if torch.cuda.is_available():
24
+ return torch.device("cuda")
25
+ if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
26
+ return torch.device("mps")
27
+ return torch.device("cpu")
28
+
29
+
30
# Resolved once at import time; every model and input tensor is moved here.
device = _get_device()
# Half precision halves weight memory on CUDA/MPS; CPU stays float32
# (fp16 matmul is slow or unsupported on CPU).
_TORCH_DTYPE = torch.float16 if device.type in {"cuda", "mps"} else torch.float32
32
+
33
+
34
+ def _resolve_source(local_dir: str, hub_id: str) -> str:
35
+ """
36
+ Prefer a local directory if it exists; otherwise use a Hugging Face Hub repo id.
37
+ """
38
+ if local_dir and os.path.isdir(local_dir):
39
+ return local_dir
40
+ return hub_id
41
+
42
+
43
+ # ================================
44
+ # EXPERIMENT GRAPH FUNCTIONS
45
+ # ================================
46
+
47
def plot_beam_experiment():
    """Line chart of CIDEr score versus beam size for the three models.

    The numbers are hard-coded results from the offline beam-search sweep.
    Returns the matplotlib Figure for st.pyplot.
    """
    beam_sizes = [1, 3, 5, 10]
    # Legend order matters: BLIP, ViT-GPT2, GIT (dict preserves insertion order).
    series_by_model = {
        "BLIP": [0.52, 0.59, 0.61, 0.60],
        "ViT-GPT2": [0.50, 0.56, 0.60, 0.58],
        "GIT": [0.12, 0.16, 0.17, 0.16],
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    for label, scores in series_by_model.items():
        ax.plot(beam_sizes, scores, marker='o', linewidth=3, label=label)

    ax.set_xlabel("Beam Size")
    ax.set_ylabel("CIDEr Score")
    ax.set_title("Beam Size vs Caption Quality")
    ax.legend()
    ax.grid(True)

    return fig
69
+
70
+
71
def plot_caption_length():
    """Grouped bar chart: CIDEr per caption-length bucket for each model.

    Hard-coded offline results; returns the matplotlib Figure for st.pyplot.
    """
    categories = ["Short", "Medium", "Long"]
    # Bar-group order: BLIP (left), ViT-GPT2 (center), GIT (right).
    series_by_model = {
        "BLIP": [0.71, 0.60, 0.48],
        "ViT-GPT2": [0.65, 0.59, 0.42],
        "GIT": [0.30, 0.18, 0.11],
    }

    positions = np.arange(len(categories))
    bar_width = 0.25
    offsets = [-bar_width, 0.0, bar_width]

    fig, ax = plt.subplots(figsize=(10, 6))
    for offset, (label, scores) in zip(offsets, series_by_model.items()):
        ax.bar(positions + offset, scores, bar_width, label=label)

    ax.set_xlabel("Caption Length Category")
    ax.set_ylabel("CIDEr Score")
    ax.set_title("Model Performance vs Caption Length")
    ax.set_xticks(positions)
    ax.set_xticklabels(categories)
    ax.legend()

    return fig
98
+
99
+
100
+ # ================================
101
+ # UI STYLE
102
+ # ================================
103
+
104
+ st.markdown("""
105
+ <style>
106
+
107
+ .main-title{
108
+ text-align:center;
109
+ font-size:42px;
110
+ font-weight:bold;
111
+ margin-bottom:10px;
112
+ }
113
+
114
+ .subtitle{
115
+ text-align:center;
116
+ font-size:18px;
117
+ color:gray;
118
+ margin-bottom:30px;
119
+ }
120
+
121
+ .caption-box{
122
+ background-color:white;
123
+ padding:20px;
124
+ border-radius:14px;
125
+ text-align:center;
126
+ font-size:18px;
127
+ min-height:120px;
128
+ display:flex;
129
+ align-items:center;
130
+ justify-content:center;
131
+ color:black;
132
+ font-weight:500;
133
+ box-shadow:0px 4px 12px rgba(0,0,0,0.15);
134
+ }
135
+
136
+ .model-title{
137
+ text-align:center;
138
+ font-size:22px;
139
+ font-weight:bold;
140
+ margin-bottom:10px;
141
+ }
142
+
143
+ </style>
144
+ """, unsafe_allow_html=True)
145
+
146
+
147
+ # ================================
148
+ # LOAD MODELS
149
+ # ================================
150
+
151
@st.cache_resource
def load_blip():
    """Load the fine-tuned BLIP captioning model and processor.

    Cached by Streamlit so weights are downloaded/loaded once per session.
    Returns (model, processor) with the model in eval mode on `device`.
    """
    source = _resolve_source(
        os.getenv("BLIP_LOCAL_DIR", "saved_model_phase2"),
        os.getenv("BLIP_MODEL_ID", "pchandragrid/blip-caption-model"),
    )

    processor = BlipProcessor.from_pretrained(source)
    model = BlipForConditionalGeneration.from_pretrained(
        source,
        torch_dtype=_TORCH_DTYPE,
        low_cpu_mem_usage=True,
    )
    model.to(device)
    model.eval()
    return model, processor
167
+
168
+
169
@st.cache_resource
def load_vit_gpt2():
    """Load the ViT-GPT2 encoder-decoder captioner plus its processor/tokenizer.

    Cached by Streamlit; returns (model, processor, tokenizer) with the model
    in eval mode on `device`.
    """
    source = _resolve_source(
        os.getenv("VITGPT2_LOCAL_DIR", "saved_vit_gpt2"),
        os.getenv("VITGPT2_MODEL_ID", "pchandragrid/vit-gpt2-caption-model"),
    )

    processor = ViTImageProcessor.from_pretrained(source)
    tokenizer = AutoTokenizer.from_pretrained(source)
    model = VisionEncoderDecoderModel.from_pretrained(
        source,
        torch_dtype=_TORCH_DTYPE,
        low_cpu_mem_usage=True,
    )
    model.to(device)
    model.eval()
    return model, processor, tokenizer
186
+
187
+
188
@st.cache_resource
def load_git():
    """Load the GIT (Generative Image-to-Text) captioner and its processor.

    Cached by Streamlit; returns (model, processor) with the model in eval
    mode on `device`.
    """
    source = _resolve_source(
        os.getenv("GIT_LOCAL_DIR", "saved_git_model"),
        os.getenv("GIT_MODEL_ID", "pchandragrid/git-caption-model"),
    )

    model = GitForCausalLM.from_pretrained(
        source,
        torch_dtype=_TORCH_DTYPE,
        low_cpu_mem_usage=True,
    )
    processor = GitProcessor.from_pretrained(source)
    model.to(device)
    model.eval()
    return model, processor
204
+
205
+
206
+ # ================================
207
+ # HEADER
208
+ # ================================
209
+
210
+ st.markdown('<div class="main-title">🖼️ Image Captioning</div>', unsafe_allow_html=True)
211
+
212
+ st.markdown(
213
+ '<div class="subtitle">Compare BLIP vs ViT-GPT2 vs GIT on the same image</div>',
214
+ unsafe_allow_html=True
215
+ )
216
+
217
+
218
+ st.markdown("""
219
+ ### 📌 Project Overview
220
+
221
+ This project focuses on **automatic image caption generation using transformer-based vision-language models**.
222
+
223
+ The system takes an input image and generates a natural language description of the scene.
224
+
225
+ Three architectures are evaluated:
226
+
227
+ • **BLIP (Bootstrapping Language Image Pretraining)** – multimodal transformer designed specifically for vision-language tasks
228
+ • **ViT-GPT2** – Vision Transformer encoder combined with GPT2 text decoder
229
+ • **GIT (Generative Image-to-Text Transformer)** – unified transformer architecture for image-to-text generation
230
+
231
+ The goal of this project is to **compare model architectures, caption quality, and generation performance** using the COCO dataset.
232
+
233
+ ---
234
+
235
+ ### 🎯 Project Objective
236
+
237
+ Improve caption generation performance through **fine-tuning and decoding optimization**.
238
+
239
+ Training pipeline:
240
+
241
+ **Step 1 — Dataset Preparation**
242
+ - Use **MS COCO captions dataset**
243
+ - Train on a **10k–50k image-caption subset**
244
+
245
+ **Step 2 — Model Fine-Tuning**
246
+ - Fine-tune **BLIP or VisionEncoderDecoder models**
247
+
248
+ **Step 3 — Training Configuration**
249
+ - Train with image resolution **224–384 px**
250
+ - Train for **3 epochs**
251
+
252
+ **Step 4 — Memory Optimization**
253
+ - Use **gradient checkpointing** to reduce GPU memory usage
254
+
255
+ **Step 5 — Target Performance**
256
+ - Achieve **10%+ improvement in CIDEr score** compared to baseline models
257
+
258
+ These steps allow the system to learn stronger **image-text alignment and caption generation capability**.
259
+ """)
260
+
261
+
262
+ # ================================
263
+ # SIDEBAR
264
+ # ================================
265
+
266
# --- Sidebar: model selection and decoding controls -------------------------
st.sidebar.header("⚙️ Generation Settings")

st.sidebar.subheader("Models to run")
run_blip = st.sidebar.checkbox("BLIP", value=True)
run_vit = st.sidebar.checkbox("ViT-GPT2", value=False)
run_git = st.sidebar.checkbox("GIT", value=False)

# Shared decoding hyper-parameters (min, max, default).
num_beams = st.sidebar.slider("Beam Size", 1, 10, 5)
max_length = st.sidebar.slider("Max Length", 10, 50, 20)
length_penalty = st.sidebar.slider("Length Penalty", 0.5, 2.0, 1.0, step=0.1)


uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
279
+
280
+
281
+ # ================================
282
+ # IMAGE DISPLAY
283
+ # ================================
284
+
285
+ if uploaded_file:
286
+
287
+ image = Image.open(uploaded_file).convert("RGB")
288
+
289
+ st.markdown(
290
+ """
291
+ <div style="text-align:center;font-size:22px;font-weight:bold;margin-bottom:10px;">
292
+ Uploaded Image
293
+ </div>
294
+ """,
295
+ unsafe_allow_html=True
296
+ )
297
+
298
+ st.image(image, use_container_width=True)
299
+
300
+ if st.button("Generate Captions"):
301
+
302
+ with st.spinner("Running models..."):
303
+
304
+ if not any([run_blip, run_vit, run_git]):
305
+ st.warning("Select at least one model in the sidebar.")
306
+ st.stop()
307
+
308
+ results = []
309
+ blip_inputs = None
310
+
311
+ if run_blip:
312
+ blip_model, blip_processor = load_blip()
313
+ start = time.time()
314
+ blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
315
+ with torch.no_grad():
316
+ blip_ids = blip_model.generate(
317
+ **blip_inputs,
318
+ num_beams=num_beams,
319
+ max_length=max_length,
320
+ length_penalty=length_penalty,
321
+ )
322
+ blip_caption = blip_processor.decode(blip_ids[0], skip_special_tokens=True)
323
+ results.append(("BLIP", blip_caption, time.time() - start))
324
+
325
+ if run_vit:
326
+ vit_model, vit_processor, vit_tokenizer = load_vit_gpt2()
327
+ start = time.time()
328
+ pixel_values = vit_processor(images=image, return_tensors="pt").pixel_values.to(device)
329
+ with torch.no_grad():
330
+ vit_ids = vit_model.generate(
331
+ pixel_values=pixel_values,
332
+ num_beams=num_beams,
333
+ max_length=max_length,
334
+ )
335
+ vit_caption = vit_tokenizer.decode(vit_ids[0], skip_special_tokens=True)
336
+ results.append(("ViT-GPT2", vit_caption, time.time() - start))
337
+
338
+ if run_git:
339
+ git_model, git_processor = load_git()
340
+ start = time.time()
341
+ git_inputs = git_processor(images=image, return_tensors="pt").to(device)
342
+ with torch.no_grad():
343
+ git_ids = git_model.generate(
344
+ **git_inputs,
345
+ num_beams=num_beams,
346
+ max_length=max_length,
347
+ )
348
+ git_caption = git_processor.batch_decode(git_ids, skip_special_tokens=True)[0]
349
+ results.append(("GIT", git_caption, time.time() - start))
350
+
351
+
352
+ st.divider()
353
+
354
+ st.subheader("Model Comparison")
355
+
356
+ st.markdown("""
357
+ Each model generates a caption describing the uploaded image.
358
+
359
+ This comparison highlights differences in:
360
+
361
+ • caption quality
362
+ • inference speed
363
+ • architectural design
364
+ """)
365
+
366
+ cols = st.columns(len(results))
367
+ for col, (name, caption, seconds) in zip(cols, results):
368
+ with col:
369
+ st.markdown(f'<div class="model-title">{name}</div>', unsafe_allow_html=True)
370
+ st.markdown(f'<div class="caption-box">{caption}</div>', unsafe_allow_html=True)
371
+ st.caption(f"Inference: {seconds:.2f}s")
372
+
373
+
374
+ st.divider()
375
+
376
+
377
+ # ================================
378
+ # ATTENTION HEATMAP
379
+ # ================================
380
+
381
+ if run_blip and blip_inputs is not None:
382
+ blip_model, _ = load_blip()
383
+ with torch.no_grad():
384
+ vision_outputs = blip_model.vision_model(
385
+ blip_inputs["pixel_values"],
386
+ output_attentions=True,
387
+ return_dict=True,
388
+ )
389
+
390
+ attentions = vision_outputs.attentions[-1]
391
+
392
+ attn = attentions[0].mean(0)
393
+ cls_attn = attn[0, 1:]
394
+
395
+ attn_map = cls_attn.cpu().numpy()
396
+ attn_map = attn_map / attn_map.max()
397
+
398
+ size = int(np.sqrt(len(attn_map)))
399
+
400
+ fig, ax = plt.subplots(figsize=(6, 6))
401
+
402
+ ax.imshow(attn_map.reshape(size, size), cmap="viridis")
403
+ ax.set_title("BLIP Vision Attention")
404
+ ax.axis("off")
405
+
406
+ st.pyplot(fig, use_container_width=True)
407
+
408
+ st.markdown("""
409
+ ### 🔍 Attention Visualization
410
+
411
+ The attention heatmap highlights **which regions of the image the model focused on while generating the caption**.
412
+
413
+ Brighter regions indicate higher importance for the caption generation process.
414
+ """)
415
+
416
+
417
+ # ================================
418
+ # ARCHITECTURE COMPARISON TABLE
419
+ # ================================
420
+
421
st.divider()
st.header("📊 Model Architecture Comparison")

# Static summary of the three fine-tuned architectures (offline numbers).
data = {
    "Model": ["BLIP", "ViT-GPT2", "GIT"],
    "Architecture": [
        "Vision Transformer + Text Decoder",
        "ViT Encoder + GPT2 Decoder",
        "Unified Transformer",
    ],
    "Parameters": ["~224M", "~210M", "~150M"],
    "Training Time": ["~1h 34m / epoch", "~1h 20m / epoch", "~11 min / epoch"],
    "CIDEr Score": ["0.61", "0.60", "0.17"],
}

df = pd.DataFrame(data)
st.table(df)
439
+
440
+
441
+ # ================================
442
+ # EXPERIMENT GRAPHS
443
+ # ================================
444
+
445
# --- Experiment analysis: static charts from the offline sweeps -------------
st.divider()
st.header("📊 Experiment Analysis")

st.subheader("Beam Size vs Caption Quality")

st.pyplot(plot_beam_experiment(), use_container_width=True)

st.markdown("""
Beam search controls how many candidate captions are explored during generation.
Increasing beam size improves caption quality initially but eventually leads to diminishing returns.
""")


st.divider()

st.subheader("Caption Length vs Model Performance")

st.pyplot(plot_caption_length(), use_container_width=True)

st.markdown("""
Caption length impacts performance because longer captions require more detailed reasoning about the scene.
Models generally perform better on shorter captions.
""")
app/streamlit_app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from PIL import Image
6
+ import streamlit as st
7
+ from transformers import (
8
+ AutoModelForSequenceClassification,
9
+ AutoTokenizer,
10
+ BlipForConditionalGeneration,
11
+ BlipProcessor,
12
+ )
13
+
14
+
15
@st.cache_resource
def load_caption_model():
    """Load the fine-tuned BLIP captioner from ``saved_model_phase2``.

    Cached by Streamlit so the weights load once per session.

    Returns:
        (model, processor, device): model is in eval mode and already moved
        to Apple MPS when available, otherwise CPU.
    """
    # getattr-guard: older torch builds have no ``torch.backends.mps`` attribute
    # at all, so the unguarded ``torch.backends.mps.is_available()`` would raise
    # AttributeError instead of falling back to CPU.
    mps = getattr(torch.backends, "mps", None)
    device = torch.device("mps" if mps is not None and mps.is_available() else "cpu")

    model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2")
    processor = BlipProcessor.from_pretrained("saved_model_phase2")

    model.to(device)
    model.eval()

    return model, processor, device
26
+
27
+
28
@st.cache_resource
def load_toxicity_model():
    """Load ``unitary/toxic-bert`` for caption toxicity screening.

    Cached by Streamlit so the weights load once per session.

    Returns:
        (model, tokenizer, device): model is in eval mode and already moved
        to Apple MPS when available, otherwise CPU.
    """
    # Same guard as load_caption_model: torch builds without the MPS backend
    # would otherwise raise AttributeError on ``torch.backends.mps``.
    mps = getattr(torch.backends, "mps", None)
    device = torch.device("mps" if mps is not None and mps.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert")

    model.to(device)
    model.eval()

    return model, tokenizer, device
39
+
40
+
41
+ caption_model, caption_processor, device = load_caption_model()
42
+ tox_model, tox_tokenizer, tox_device = load_toxicity_model()
43
+
44
+
45
+ st.title("🖼️ Advanced Image Captioning Demo")
46
+ st.write("Fine-tuned BLIP with Beam Search + Toxicity Filtering")
47
+
48
+ st.sidebar.header("⚙️ Generation Settings")
49
+
50
+ num_beams = st.sidebar.slider("Beam Size", 1, 10, 5)
51
+ max_length = st.sidebar.slider("Max Length", 10, 50, 20)
52
+ length_penalty = st.sidebar.slider("Length Penalty", 0.5, 2.0, 1.0, step=0.1)
53
+
54
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
55
+
56
+ if uploaded_file:
57
+ image = Image.open(uploaded_file).convert("RGB")
58
+ st.image(image, caption="Uploaded Image", width="stretch")
59
+
60
+ if st.button("Generate Caption"):
61
+ # Generate caption
62
+ with st.spinner("Generating caption..."):
63
+ inputs = caption_processor(
64
+ images=image,
65
+ return_tensors="pt",
66
+ ).to(device)
67
+
68
+ with torch.no_grad():
69
+ output_ids = caption_model.generate(
70
+ **inputs,
71
+ num_beams=num_beams,
72
+ max_length=max_length,
73
+ length_penalty=length_penalty,
74
+ )
75
+
76
+ caption = caption_processor.decode(
77
+ output_ids[0],
78
+ skip_special_tokens=True,
79
+ )
80
+
81
+ # Confidence score (stable)
82
+ with torch.no_grad():
83
+ loss_inputs = caption_processor(
84
+ images=image,
85
+ text=caption,
86
+ return_tensors="pt",
87
+ ).to(device)
88
+
89
+ outputs = caption_model(
90
+ pixel_values=loss_inputs["pixel_values"],
91
+ input_ids=loss_inputs["input_ids"],
92
+ attention_mask=loss_inputs["attention_mask"],
93
+ labels=loss_inputs["input_ids"],
94
+ )
95
+
96
+ loss = outputs.loss
97
+ confidence = torch.exp(-loss).item() if loss is not None else 0.0
98
+
99
+ # Toxicity check
100
+ tox_inputs = tox_tokenizer(
101
+ caption,
102
+ return_tensors="pt",
103
+ truncation=True,
104
+ ).to(tox_device)
105
+
106
+ with torch.no_grad():
107
+ tox_outputs = tox_model(**tox_inputs)
108
+ probs = F.softmax(tox_outputs.logits, dim=-1)
109
+
110
+ toxic_score = probs[0][1].item()
111
+
112
+ # Display caption
113
+ if toxic_score > 0.6:
114
+ st.error("⚠️ Generated caption flagged as potentially toxic.")
115
+ st.markdown("### 🚫 Caption Blocked")
116
+ else:
117
+ st.success("Caption Generated")
118
+ st.markdown(f"### 📝 {caption}")
119
+ st.caption(f"Toxicity Score: {toxic_score:.2f}")
120
+ st.caption(f"Confidence Score: {confidence:.2f}")
121
+
122
+ # Vision attention heatmap
123
+ with torch.no_grad():
124
+ vision_outputs = caption_model.vision_model(
125
+ inputs["pixel_values"],
126
+ output_attentions=True,
127
+ return_dict=True,
128
+ )
129
+
130
+ attentions = vision_outputs.attentions[-1]
131
+ attn = attentions[0].mean(0)
132
+
133
+ cls_attn = attn[0, 1:]
134
+ attn_map = cls_attn.cpu().numpy()
135
+ attn_map = attn_map / attn_map.max()
136
+
137
+ size = int(np.sqrt(len(attn_map)))
138
+
139
+ fig, ax = plt.subplots()
140
+ ax.imshow(attn_map.reshape(size, size), cmap="viridis")
141
+ ax.set_title("Vision Attention Heatmap")
142
+ ax.axis("off")
143
+
144
+ st.pyplot(fig)
145
+
beam_search_experiments.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import BlipProcessor, BlipForConditionalGeneration
4
+ from dataset_advanced import COCODataset
5
+ from torch.utils.data import random_split
6
+ from tqdm import tqdm
7
+ from PIL import Image
8
+ from pycocoevalcap.cider.cider import Cider
9
+
10
+
11
def generate_caption(model, processor, image, device,
                     num_beams=5,
                     max_length=20,
                     length_penalty=1.0):
    """Beam-search caption a single PIL image and return the decoded text.

    Args:
        model/processor: BLIP model + processor pair.
        image: PIL.Image, already RGB.
        device: torch.device the model lives on.
        num_beams/max_length/length_penalty: decoding hyper-parameters.
    """
    batch = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **batch,
            num_beams=num_beams,
            max_length=max_length,
            length_penalty=length_penalty
        )

    return processor.decode(output_ids[0], skip_special_tokens=True)
31
+
32
+
33
def evaluate_config(model, processor, val_dataset, device,
                    num_beams, max_length, length_penalty,
                    max_samples=200):
    """Score one decoding configuration with CIDEr over up to *max_samples* images.

    *val_dataset* is a torch ``Subset`` produced by ``random_split``; we reach
    through ``.indices`` / ``.dataset`` to read the raw annotation dicts.
    Restores the model to train mode before returning the CIDEr score.
    """
    model.eval()
    scorer = Cider()

    references = {}
    hypotheses = {}

    print(f"\nTesting: beams={num_beams}, "
          f"max_len={max_length}, "
          f"len_penalty={length_penalty}")

    sample_count = min(max_samples, len(val_dataset))
    for idx in tqdm(range(sample_count)):
        # Map the Subset position back to the underlying dataset's annotation.
        ann = val_dataset.dataset.annotations[val_dataset.indices[idx]]

        image = Image.open(os.path.join("train2017", ann["image"])).convert("RGB")

        hypotheses[idx] = [generate_caption(
            model,
            processor,
            image,
            device,
            num_beams=num_beams,
            max_length=max_length,
            length_penalty=length_penalty
        )]
        references[idx] = ann["captions"]

    score, _ = scorer.compute_score(references, hypotheses)

    print(f"CIDEr: {score:.4f}")

    model.train()
    return score
73
+
74
+
75
def main():
    """Run the beam-search decoding grid over the Phase-2 BLIP checkpoint."""

    # Script is hard-wired to Apple Silicon; fail fast anywhere else.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # Load best Phase 2 model
    model_dir = "saved_model_phase2"

    processor = BlipProcessor.from_pretrained(model_dir)
    model = BlipForConditionalGeneration.from_pretrained(model_dir)

    model.to(device)

    # Load validation split
    # NOTE(review): this module imports `COCODataset` from dataset_advanced,
    # but that file appears to define `COCODatasetAdvanced` — confirm the
    # import target actually exists.
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor
    )

    # 90/10 split. random_split is unseeded here, so the validation subset
    # differs between runs — scores are comparable only within one run.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    _, val_dataset = random_split(
        full_dataset,
        [train_size, val_size]
    )

    # =========================
    # Experiment Grid
    # =========================

    # Currently a 1x1x1 grid (the best-found config); widen the lists to sweep.
    beam_sizes = [5]
    max_lengths = [20]
    length_penalties = [1.0]

    results = []

    for beams in beam_sizes:
        for max_len in max_lengths:
            for lp in length_penalties:

                score = evaluate_config(
                    model,
                    processor,
                    val_dataset,
                    device,
                    num_beams=beams,
                    max_length=max_len,
                    length_penalty=lp
                )

                results.append((beams, max_len, lp, score))

    print("\n===== FINAL RESULTS =====")
    for r in results:
        print(f"Beams={r[0]}, MaxLen={r[1]}, "
              f"LenPenalty={r[2]} -> CIDEr={r[3]:.4f}")
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
create_subset_20k.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Sample a random 20k subset of the COCO caption annotations (JSONL in, JSONL out)."""
import json
import random

INPUT_PATH = "annotations/captions_train.jsonl"
OUTPUT_PATH = "annotations/subset_20k.jsonl"
SUBSET_SIZE = 20000


def main():
    """Read all annotation records, sample SUBSET_SIZE of them, write them back out."""
    with open(INPUT_PATH, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f]

    # random.sample raises ValueError if fewer than SUBSET_SIZE records exist,
    # which is the desired loud failure for a truncated annotations file.
    subset = random.sample(records, SUBSET_SIZE)

    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        for item in subset:
            f.write(json.dumps(item) + "\n")

    print("20k subset created.")


# Guarded entry point so importing this module has no side effects.
if __name__ == "__main__":
    main()
dataset_384.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+
7
+
8
class COCODataset384(Dataset):
    """COCO-style captioning dataset that feeds 384x384 images to a
    BLIP-style processor (joint image + text encoding).
    """

    def __init__(self, annotation_path, image_folder, processor):
        self.image_folder = image_folder
        self.processor = processor

        # One {"image": ..., "captions": [...]} record per JSONL line.
        with open(annotation_path, "r") as handle:
            self.annotations = [json.loads(row) for row in handle]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        record = self.annotations[idx]
        # Train against one randomly chosen reference caption per access.
        text = random.choice(record["captions"])

        img = Image.open(
            os.path.join(self.image_folder, record["image"])
        ).convert("RGB")
        # The vision backbone expects fixed 384px inputs.
        img = img.resize((384, 384))

        encoded = self.processor(
            img,
            text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        return {
            "pixel_values": encoded["pixel_values"].squeeze(0),
            "input_ids": token_ids,
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": token_ids.clone(),
        }
dataset_advanced.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+
9
class COCODatasetAdvanced(Dataset):
    """COCO captioning dataset that drops low-quality captions and
    restricts caption length by `mode` ("short", "long", or "mixed").
    """

    def __init__(self,
                 annotation_path,
                 image_folder,
                 processor,
                 mode="mixed",
                 max_length=40):
        self.image_folder = image_folder
        self.processor = processor
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as handle:
            raw_records = [json.loads(row) for row in handle]

        self.annotations = []
        for record in raw_records:
            kept = [c for c in map(self._filter_caption, record["captions"])
                    if c is not None]
            # Keep the image only if at least one caption survived.
            if kept:
                self.annotations.append({
                    "image": record["image"],
                    "captions": kept,
                })

    def _filter_caption(self, cap):
        """Normalize a caption and return it if it passes the quality and
        mode-specific length filters; return None otherwise."""
        cap = cap.strip().lower()

        # Quality: drop very short captions.
        if len(cap.split()) < 3:
            return None

        # Quality: drop captions dominated by repeated words.
        words = cap.split()
        if len(set(words)) < len(words) * 0.6:
            return None

        # Quality: drop captions with no alphabetic content.
        if re.search(r"[a-z]", cap) is None:
            return None

        word_count = len(words)

        # Length: restrict by the configured mode.
        if self.mode == "short" and word_count <= 8:
            return cap
        if self.mode == "long" and word_count > 15:
            return cap
        if self.mode == "mixed":
            return cap
        return None

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        record = self.annotations[idx]
        text = random.choice(record["captions"])

        image = Image.open(
            os.path.join(self.image_folder, record["image"])
        ).convert("RGB")

        encoded = self.processor(
            images=image,
            text=text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        return {
            "pixel_values": encoded["pixel_values"].squeeze(0),
            "input_ids": token_ids,
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": token_ids.clone(),
        }
dataset_git.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+ import json
7
+
8
+
9
class COCODatasetGIT(Dataset):
    """COCO captioning dataset for GIT-style models; picks a caption per
    access according to the length `mode` ("short", "long", "mixed").
    """

    def __init__(self, annotation_file, image_folder, processor, mode="mixed"):
        self.image_folder = image_folder
        self.processor = processor
        self.mode = mode

        # JSONL: one {"image": ..., "captions": [...]} record per line.
        self.annotations = []
        with open(annotation_file, "r") as handle:
            for row in handle:
                self.annotations.append(json.loads(row.strip()))

    def __len__(self):
        return len(self.annotations)

    def select_caption(self, captions):
        """Return one caption, filtered by mode (<=10 words for "short",
        >10 for "long"); falls back to a random annotation's captions if
        filtering removes every candidate."""
        if self.mode == "short":
            captions = [c for c in captions if len(c.split()) <= 10]
        elif self.mode == "long":
            captions = [c for c in captions if len(c.split()) > 10]

        # Fallback: borrow captions from a random record so a caption is
        # always returned (note: these may describe a different image).
        if not captions:
            donor = random.randint(0, len(self.annotations) - 1)
            captions = self.annotations[donor]["captions"]

        return random.choice(captions)

    def __getitem__(self, idx):
        record = self.annotations[idx]

        image = Image.open(
            os.path.join(self.image_folder, record["image"])
        ).convert("RGB")

        text = self.select_caption(record["captions"])

        encoded = self.processor(
            images=image,
            text=text,
            padding="max_length",
            truncation=True,
            max_length=30,
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        return {
            "pixel_values": encoded["pixel_values"].squeeze(0),
            "input_ids": token_ids,
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": token_ids,  # GIT uses input_ids as labels
        }
dataset_vit_gpt2.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+
7
+
8
class COCODatasetViTGPT2(Dataset):
    """COCO captioning dataset for ViT-encoder + GPT-2-decoder models,
    which take a separate image processor and text tokenizer.
    """

    def __init__(self,
                 annotation_path,
                 image_folder,
                 image_processor,
                 tokenizer,
                 mode="short",
                 max_length=20):
        self.image_folder = image_folder
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as handle:
            raw_records = [json.loads(row) for row in handle]

        self.annotations = []
        for record in raw_records:
            kept = []
            for text in record["captions"]:
                wc = len(text.split())
                # Keep captions matching the requested length regime.
                if ((mode == "short" and wc <= 8)
                        or (mode == "long" and wc > 15)
                        or mode == "mixed"):
                    kept.append(text)

            if kept:
                self.annotations.append({
                    "image": record["image"],
                    "captions": kept,
                })

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        record = self.annotations[idx]
        text = random.choice(record["captions"])

        image = Image.open(
            os.path.join(self.image_folder, record["image"])
        ).convert("RGB")

        pixel_values = self.image_processor(
            images=image,
            return_tensors="pt",
        ).pixel_values.squeeze(0)

        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "pixel_values": pixel_values,
            "labels": tokens.input_ids.squeeze(0),
        }
evaluate.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import (
6
+ BlipProcessor,
7
+ BlipForConditionalGeneration,
8
+ AutoTokenizer,
9
+ AutoModelForSequenceClassification
10
+ )
11
+ from PIL import Image
12
+
13
+ # ---------------------------------------
14
+ # Load Models
15
+ # ---------------------------------------
16
def load_models():
    """Load the fine-tuned BLIP captioner and the toxicity classifier.

    Returns:
        (caption_model, caption_processor, tox_model, tox_tokenizer, device)
        with both models moved to `device` and set to eval mode.
    """
    # Prefer the Apple-silicon GPU when present, otherwise fall back to CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    # Captioning model fine-tuned in Phase 2.
    caption_processor = BlipProcessor.from_pretrained("saved_model_phase2")
    caption_model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2")
    caption_model.to(device)
    caption_model.eval()

    # Toxicity model
    tox_tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    tox_model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert")
    tox_model.to(device)
    tox_model.eval()

    return caption_model, caption_processor, tox_model, tox_tokenizer, device
36
+
37
+
38
+ # ---------------------------------------
39
+ # Generate Caption + Confidence
40
+ # ---------------------------------------
41
def generate_caption(model, processor, image, device):
    """Caption a single image and return (caption, confidence).

    Confidence is exp(sequence score) of the best beam, i.e. the model's
    length-normalized probability for the generated sequence.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        generation = model.generate(
            **inputs,
            num_beams=5,
            max_length=20,
            length_penalty=1.0,
            output_scores=True,
            return_dict_in_generate=True,
        )

    caption = processor.decode(
        generation.sequences[0],
        skip_special_tokens=True,
    )

    # True confidence: exponentiate the best beam's sequence log-score.
    confidence = torch.exp(generation.sequences_scores[0]).item()

    return caption, confidence
66
+
67
+
68
+ # ---------------------------------------
69
+ # Toxicity Score
70
+ # ---------------------------------------
71
def check_toxicity(tox_model, tox_tokenizer, caption, device):
    """Score a caption for toxicity with unitary/toxic-bert.

    Returns the probability (0..1) of the general "toxic" label.

    Fix: toxic-bert is a *multi-label* classifier (labels: toxic,
    severe_toxic, obscene, threat, insult, identity_hate) trained with a
    per-label sigmoid/BCE objective, so each logit must be squashed
    independently with sigmoid. The previous softmax across labels was
    incorrect, and index 1 is "severe_toxic" — the general "toxic" label
    is index 0.
    """
    inputs = tox_tokenizer(
        caption,
        return_tensors="pt",
        truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = tox_model(**inputs)
        # Independent per-label probabilities for a multi-label head.
        probs = torch.sigmoid(outputs.logits)

    # Label order: [toxic, severe_toxic, obscene, threat, insult, identity_hate]
    toxic_score = probs[0][0].item()
    return toxic_score
85
+
86
+
87
+ # ---------------------------------------
88
+ # Evaluate Single Image
89
+ # ---------------------------------------
90
def evaluate_image(image_path, models):
    """Caption one image, score the caption for toxicity, and print a report."""
    caption_model, caption_processor, tox_model, tox_tokenizer, device = models

    image = Image.open(image_path).convert("RGB")

    caption, confidence = generate_caption(
        caption_model, caption_processor, image, device
    )
    toxic_score = check_toxicity(tox_model, tox_tokenizer, caption, device)

    print("\n===================================")
    print("Image:", image_path)
    print("Caption:", caption)
    print(f"Confidence: {confidence:.3f}")
    print(f"Toxicity Score: {toxic_score:.3f}")
    # Flag captions whose toxicity probability crosses the 0.6 threshold.
    if toxic_score > 0.6:
        print("⚠️ WARNING: Caption flagged as toxic")
    print("===================================\n")
119
+
120
+
121
+ # ---------------------------------------
122
+ # Main
123
+ # ---------------------------------------
124
def main():
    """CLI entry point: evaluate one image (--image) or a folder (--folder)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--folder", type=str, help="Path to folder of images")
    args = parser.parse_args()

    if not (args.image or args.folder):
        print("Please provide --image or --folder")
        return

    models = load_models()

    if args.image:
        evaluate_image(args.image, models)

    if args.folder:
        # Evaluate every image file in the folder (jpg/jpeg/png).
        for name in os.listdir(args.folder):
            if name.lower().endswith((".jpg", ".jpeg", ".png")):
                evaluate_image(os.path.join(args.folder, name), models)


if __name__ == "__main__":
    main()
plot/beam_experiment_plot.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Plot the effect of beam size on CIDEr for the three caption models."""
import matplotlib.pyplot as plt

# Beam widths evaluated in the decoding experiments.
beam_sizes = [1, 3, 5, 10]

# Example CIDEr scores from experiments (one score per beam width).
blip_scores = [0.52, 0.59, 0.61, 0.60]
vit_scores = [0.50, 0.56, 0.60, 0.58]
git_scores = [0.12, 0.16, 0.17, 0.16]

plt.figure(figsize=(8, 5))

# One line per model, same marker style throughout.
for label, scores in (("BLIP", blip_scores),
                      ("ViT-GPT2", vit_scores),
                      ("GIT", git_scores)):
    plt.plot(beam_sizes, scores, marker='o', label=label)

plt.xlabel("Beam Size")
plt.ylabel("CIDEr Score")
plt.title("Effect of Beam Size on Caption Quality")
plt.legend()
plt.grid(True)

plt.savefig("beam_search_experiment.png", dpi=300)
plt.show()
plot/caption_length_analysis.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Bucket validation captions by word count and plot CIDEr per bucket."""
import json

import matplotlib.pyplot as plt
import numpy as np

ANNOTATION_FILE = "annotations/captions_validation.jsonl"

# Buckets keyed by word count of each image's first reference caption.
short = []
medium = []
long = []

with open(ANNOTATION_FILE) as handle:
    for row in handle:
        record = json.loads(row)
        n_words = len(record["captions"][0].split())

        if n_words <= 8:
            short.append(n_words)
        elif n_words <= 15:
            medium.append(n_words)
        else:
            long.append(n_words)

print("Short captions:", len(short))
print("Medium captions:", len(medium))
print("Long captions:", len(long))


# Example scores from your training logs
blip_scores = [0.71, 0.60, 0.48]
vit_scores = [0.65, 0.59, 0.42]
git_scores = [0.30, 0.18, 0.11]

labels = ["Short", "Medium", "Long"]

# Grouped bar chart: three bars (one per model) per length bucket.
x = np.arange(len(labels))
width = 0.25

plt.figure(figsize=(9, 5))
plt.bar(x - width, blip_scores, width, label="BLIP")
plt.bar(x, vit_scores, width, label="ViT-GPT2")
plt.bar(x + width, git_scores, width, label="GIT")

plt.xlabel("Caption Length")
plt.ylabel("CIDEr Score")
plt.title("Model Performance vs Caption Length")
plt.xticks(x, labels)
plt.legend()

plt.savefig("caption_length_analysis.png", dpi=300)
plt.show()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
streamlit
torch
torchvision
transformers
accelerate
datasets
Pillow
numpy
pandas
matplotlib
tqdm
pycocoevalcap
scikit-learn
src/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ Top-level package for the image captioning project.
3
+
4
+ This package exposes the core modules used in training, evaluation,
5
+ and serving the captioning models.
6
+ """
7
+
src/data/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Data loading utilities and dataset definitions.
3
+ """
4
+
src/data/coco_384_dataset.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from typing import Any, Dict
5
+
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+
9
+
10
class COCODataset384(Dataset):
    """
    COCO-style dataset that always resizes images to 384x384 and uses
    a BLIP-style processor for joint image-text encoding.
    """

    def __init__(self, annotation_path: str, image_folder: str, processor: Any) -> None:
        self.image_folder = image_folder
        self.processor = processor

        # One {"image": ..., "captions": [...]} record per JSONL line.
        with open(annotation_path, "r") as handle:
            self.annotations = [json.loads(row) for row in handle]

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.annotations[idx]
        # One randomly chosen reference caption per access.
        text = random.choice(record["captions"])

        img = Image.open(
            os.path.join(self.image_folder, record["image"])
        ).convert("RGB")
        # 384px resize for the vision backbone
        img = img.resize((384, 384))

        encoded = self.processor(
            img,
            text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        return {
            "pixel_values": encoded["pixel_values"].squeeze(0),
            "input_ids": token_ids,
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": token_ids.clone(),
        }
52
+
src/data/coco_advanced_dataset.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ from typing import Any, Dict, List
6
+
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ class COCODatasetAdvanced(Dataset):
12
+ """
13
+ COCO dataset with caption quality and length filtering.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ annotation_path: str,
19
+ image_folder: str,
20
+ processor: Any,
21
+ mode: str = "mixed",
22
+ max_length: int = 40,
23
+ ) -> None:
24
+ self.image_folder = image_folder
25
+ self.processor = processor
26
+ self.max_length = max_length
27
+ self.mode = mode
28
+
29
+ with open(annotation_path, "r") as f:
30
+ raw_data = [json.loads(line) for line in f]
31
+
32
+ self.annotations: List[Dict[str, Any]] = []
33
+
34
+ for ann in raw_data:
35
+ filtered_captions: List[str] = []
36
+
37
+ for cap in ann["captions"]:
38
+ cap = cap.strip().lower()
39
+
40
+ # Remove very short captions
41
+ if len(cap.split()) < 3:
42
+ continue
43
+
44
+ # Remove repeated words
45
+ words = cap.split()
46
+ if len(set(words)) < len(words) * 0.6:
47
+ continue
48
+
49
+ # Remove non-alphabetic captions
50
+ if not re.search(r"[a-z]", cap):
51
+ continue
52
+
53
+ word_count = len(words)
54
+
55
+ if self.mode == "short" and word_count <= 8:
56
+ filtered_captions.append(cap)
57
+ elif self.mode == "long" and word_count > 15:
58
+ filtered_captions.append(cap)
59
+ elif self.mode == "mixed":
60
+ filtered_captions.append(cap)
61
+
62
+ if filtered_captions:
63
+ self.annotations.append(
64
+ {
65
+ "image": ann["image"],
66
+ "captions": filtered_captions,
67
+ }
68
+ )
69
+
70
+ def __len__(self) -> int:
71
+ return len(self.annotations)
72
+
73
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
74
+ ann = self.annotations[idx]
75
+ file_name = ann["image"]
76
+ caption = random.choice(ann["captions"])
77
+
78
+ image_path = os.path.join(self.image_folder, file_name)
79
+ image = Image.open(image_path).convert("RGB")
80
+
81
+ encoding = self.processor(
82
+ images=image,
83
+ text=caption,
84
+ padding="max_length",
85
+ truncation=True,
86
+ max_length=self.max_length,
87
+ return_tensors="pt",
88
+ )
89
+
90
+ input_ids = encoding["input_ids"].squeeze(0)
91
+
92
+ return {
93
+ "pixel_values": encoding["pixel_values"].squeeze(0),
94
+ "input_ids": input_ids,
95
+ "attention_mask": encoding["attention_mask"].squeeze(0),
96
+ "labels": input_ids.clone(),
97
+ }
98
+
src/data/coco_vit_gpt2_dataset.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from typing import Any, Dict, List
5
+
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+
9
+
10
class COCODatasetViTGPT2(Dataset):
    """
    COCO dataset tailored for ViT + GPT-2 style architectures with
    separate image processor and tokenizer.
    """

    def __init__(
        self,
        annotation_path: str,
        image_folder: str,
        image_processor: Any,
        tokenizer: Any,
        mode: str = "short",
        max_length: int = 20,
    ) -> None:
        self.image_folder = image_folder
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as handle:
            raw_records = [json.loads(row) for row in handle]

        self.annotations: List[Dict[str, Any]] = []
        for record in raw_records:
            kept: List[str] = []
            for text in record["captions"]:
                wc = len(text.split())
                # Keep captions matching the requested length regime.
                if ((mode == "short" and wc <= 8)
                        or (mode == "long" and wc > 15)
                        or mode == "mixed"):
                    kept.append(text)

            if kept:
                self.annotations.append(
                    {
                        "image": record["image"],
                        "captions": kept,
                    }
                )

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.annotations[idx]
        text = random.choice(record["captions"])

        image = Image.open(
            os.path.join(self.image_folder, record["image"])
        ).convert("RGB")

        pixel_values = self.image_processor(
            images=image,
            return_tensors="pt",
        ).pixel_values.squeeze(0)

        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "pixel_values": pixel_values,
            "labels": tokens.input_ids.squeeze(0),
        }
87
+
src/evaluation/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Evaluation utilities (e.g., CIDEr scoring).
3
+ """
4
+
src/evaluation/cider_eval.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any
3
+
4
+ from PIL import Image
5
+ from pycocoevalcap.cider.cider import Cider
6
+ from tqdm import tqdm
7
+
8
+
9
def generate_caption(model: Any, processor: Any, image: "Image.Image", device) -> str:
    """
    Run the captioning model on a single image and return the decoded caption.

    Fix: the original used `getattr(__import__("torch"), "no_grad")()` and
    then rebound `torch` inside the context without using it — needlessly
    obscure and leaving a dead binding. A plain local import expresses the
    same "keep torch out of non-training import paths" intent. The `image`
    annotation is now a string forward reference so defining this function
    no longer requires PIL to be imported eagerly.
    """
    import torch  # local import keeps non-training code paths torch-free

    inputs = processor(images=image, return_tensors="pt").to(device)

    # Greedy-free beam decoding; gradients are never needed here.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_length=30,
            num_beams=5,
        )

    caption = processor.decode(
        generated_ids[0],
        skip_special_tokens=True,
    )
    return caption
28
+
29
+
30
def evaluate_cider(model: Any, processor: Any, val_dataset, device, max_samples: int = 200) -> float:
    """
    Compute the CIDEr score of `model` on up to `max_samples` validation items.

    Expects a PyTorch `Subset`/`Dataset` where:
      - `val_dataset.indices[idx]` gives the underlying index
      - `val_dataset.dataset.annotations[...]` is a list of dicts with
        keys `image` and `captions`.

    Side effects: switches the model to eval() for generation and back to
    train() before returning; prints the score.

    Fix: removed the local `import torch` — nothing in this function uses
    torch, and its comment misleadingly suggested it was required.

    NOTE(review): the image folder "train2017" is hard-coded here; it must
    match the folder the dataset was built from.
    """
    model.eval()

    cider_scorer = Cider()
    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):
        # Map the Subset position back to the full dataset's annotation.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        image_path = os.path.join("train2017", ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
65
+
src/streamlit_app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

# NOTE: the bare string below is intentional — Streamlit "magic" renders
# module-level string expressions as Markdown in the running app, so this
# literal is app content, not a docstring. Do not remove or translate it.
"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

# Interactive controls: sample count and number of spiral turns.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Parametric spiral: the radius grows linearly with the angle parameter.
indices = np.linspace(0, 1, num_points)
theta = 2 * np.pi * num_turns * indices
radius = indices

# Polar -> Cartesian coordinates for plotting.
x = radius * np.cos(theta)
y = radius * np.sin(theta)

df = pd.DataFrame({
    "x": x,
    "y": y,
    "idx": indices,  # position along the spiral; drives point color
    "rand": np.random.randn(num_points),  # random values; drive point size
})

# Scatter plot of the spiral: color encodes position, size is random.
st.altair_chart(alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    ))
src/training/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Training entrypoints and training utilities.
3
+ """
4
+
src/training/train_phase1.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from torch.optim import AdamW
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+ from torch.utils.data import DataLoader, random_split
8
+ from transformers import BlipForConditionalGeneration, BlipProcessor
9
+ from tqdm import tqdm
10
+
11
+ from src.data.coco_384_dataset import COCODataset384 as COCODataset
12
+ from src.evaluation.cider_eval import evaluate_cider
13
+
14
+
15
def main() -> None:
    """Phase 1 fine-tuning of BLIP on a 10k COCO subset (MPS only).

    Partially unfreezes the last two vision-encoder layers, trains with
    fp16 autocast, validates per epoch (loss + CIDEr), and keeps the
    checkpoint with the best CIDEr, with early stopping on patience=3.
    """
    # Hard requirement: this script targets Apple-silicon GPUs only.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase1"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Unfreeze LAST 2 vision layers only
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers.10" in name or "encoder.layers.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Gradient checkpointing trades compute for memory; use_cache must be
    # disabled for it to take effect during training.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
    )

    # 90/10 train/validation split.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    # Optimize only the parameters left trainable above.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR,
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = 0.0
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): torch.autocast on "mps" with float16 is
            # version-dependent, and no GradScaler is used with the fp16
            # loss — confirm the installed torch supports this combination.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        # evaluate_cider toggles model eval/train internally and prints
        # its own score line.
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

            # Stop once CIDEr has failed to improve `patience` epochs in a row.
            if counter >= patience:
                print("Early stopping triggered.")
                break

        scheduler.step()

    print("Phase 1 training complete.")


if __name__ == "__main__":
    main()
168
+
src/training/train_phase2.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.optim import AdamW
5
+ from torch.optim.lr_scheduler import CosineAnnealingLR
6
+ from torch.utils.data import DataLoader, random_split
7
+ from transformers import BlipForConditionalGeneration, BlipProcessor
8
+ from tqdm import tqdm
9
+
10
+ from src.data.coco_advanced_dataset import COCODatasetAdvanced
11
+ from src.evaluation.cider_eval import evaluate_cider
12
+
13
+
14
def main() -> None:
    """Phase-2 BLIP fine-tuning with a partially unfrozen vision tower.

    Fine-tunes Salesforce/blip-image-captioning-base on a 10k COCO subset with
    only the last two ViT encoder layers trainable, tracks validation loss and
    CIDEr each epoch, keeps the best-CIDEr checkpoint in FINAL_MODEL_DIR, and
    early-stops after `patience` epochs without CIDEr improvement.

    Raises:
        RuntimeError: if the Apple-Silicon MPS backend is unavailable.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Unfreeze LAST 2 vision layers only
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers.10" in name or "encoder.layers.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Trade compute for memory; use_cache must be off with checkpointing.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    MODE = "long"  # change to "short" or "mixed" as needed

    full_dataset = COCODatasetAdvanced(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        mode=MODE,
    )

    # 90/10 random train/validation split.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    # Optimize only the parameters left trainable above.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR,
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = 0.0
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): fp16 autocast without a GradScaler — gradients
            # may underflow; confirm this is intentional on MPS.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        # Model selection is driven by CIDEr, not validation loss.
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        if counter >= patience:
            print("Early stopping triggered.")
            break

        scheduler.step()

    print("Phase 2 training complete.")


if __name__ == "__main__":
    main()
170
+
src/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ General-purpose utility functions and scripts.
3
+ """
4
+
src/utils/data_subset.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from pathlib import Path
4
+ from typing import Iterable
5
+
6
+
7
+ def create_subset(
8
+ input_path: str | Path,
9
+ output_path: str | Path,
10
+ size: int = 20_000,
11
+ ) -> None:
12
+ """
13
+ Create a random subset of a JSONL annotations file.
14
+ """
15
+ input_path = Path(input_path)
16
+ output_path = Path(output_path)
17
+
18
+ with input_path.open("r") as f:
19
+ data = [json.loads(line) for line in f]
20
+
21
+ if size > len(data):
22
+ raise ValueError(f"Requested subset size {size} exceeds dataset size {len(data)}")
23
+
24
+ subset = random.sample(data, size)
25
+
26
+ output_path.parent.mkdir(parents=True, exist_ok=True)
27
+
28
+ with output_path.open("w") as f:
29
+ for item in subset:
30
+ f.write(json.dumps(item) + "\n")
31
+
32
+
33
def _main_from_cli(args: Iterable[str] | None = None) -> None:
    """Parse command-line options and delegate to ``create_subset``.

    When *args* is ``None``, argparse falls back to ``sys.argv``.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Create a random JSONL subset.")
    cli.add_argument(
        "--input",
        default="annotations/captions_train.jsonl",
        help="Input JSONL annotations path.",
    )
    cli.add_argument(
        "--output",
        default="annotations/subset_20k.jsonl",
        help="Output JSONL path.",
    )
    cli.add_argument(
        "--size",
        type=int,
        default=20_000,
        help="Number of samples to keep.",
    )

    opts = cli.parse_args(None if args is None else list(args))
    create_subset(opts.input, opts.output, opts.size)
    print(f"Subset of {opts.size} entries written to {opts.output}")


if __name__ == "__main__":
    _main_from_cli()
64
+
train_blip_20k_384.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import DataLoader, random_split
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ from torch.optim import AdamW
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+ from dataset_384 import COCODataset384
8
+ from tqdm import tqdm
9
+
10
+
11
def main():
    """Fine-tune BLIP-base on a 20k COCO subset at 384px resolution.

    Runs a fixed 5-epoch loop with fp16 autocast on MPS, computes validation
    loss every epoch, and persists the lowest-validation-loss model (plus
    processor) to ``saved_model_20k_384``.
    """

    # NOTE(review): no availability check or CPU fallback here — this will
    # fail on machines without MPS; confirm intended.
    device = torch.device("mps")
    print("Using device:", device)

    EPOCHS = 5
    BATCH_SIZE = 3  # ⚠️ Lower because 384px uses more memory
    LR = 3e-5

    CHECKPOINT_DIR = "checkpoints_20k_384"
    MODEL_DIR = "saved_model_20k_384"

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(MODEL_DIR, exist_ok=True)

    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Trade compute for memory; use_cache must be off with checkpointing.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    dataset = COCODataset384(
        "annotations/subset_20k.jsonl",
        "train2017",
        processor
    )

    # 90/10 random train/validation split.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = AdamW(model.parameters(), lr=LR)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_val_loss = float("inf")

    for epoch in range(EPOCHS):

        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):

            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): fp16 autocast without a GradScaler — confirm
            # gradient underflow is not a concern here.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        train_loss = total_loss / len(train_loader)
        print(f"Train Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Validation Loss: {val_loss:.4f}")

        # Keep only the best-so-far model by validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_pretrained(MODEL_DIR)
            processor.save_pretrained(MODEL_DIR)
            print("Best model saved.")

        scheduler.step()

    print("Training complete.")


if __name__ == "__main__":
    main()
train_data_experiments.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from platform import processor
3
+ import torch
4
+ from torch.utils.data import DataLoader, random_split
5
+ from transformers import BlipProcessor, BlipForConditionalGeneration
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import CosineAnnealingLR
8
+ from dataset_advanced import COCODataset
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+ from pycocoevalcap.cider.cider import Cider
12
+ from dataset_advanced import COCODatasetAdvanced
13
+
14
+
15
+ # =========================
16
+ # GENERATE CAPTION
17
+ # =========================
18
def generate_caption(model, processor, image, device):
    """Generate one beam-searched caption (max 30 tokens) for *image*.

    The image is preprocessed by *processor*, moved to *device*, and the
    best beam is decoded back to text with special tokens stripped.
    """
    model_inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_length=30, num_beams=5)

    # Only the top beam (index 0) is decoded.
    return processor.decode(output_ids[0], skip_special_tokens=True)
33
+
34
+
35
+ # =========================
36
+ # CIDEr EVALUATION
37
+ # =========================
38
def evaluate_cider(model, processor, val_dataset, device, max_samples=200):
    """Compute corpus-level CIDEr over up to *max_samples* validation items.

    *val_dataset* is assumed to be a ``torch.utils.data.Subset`` whose parent
    dataset exposes ``.annotations`` entries with "image" and "captions" keys
    — TODO confirm against COCODatasetAdvanced.  The model is put back into
    train mode before returning.
    """
    model.eval()

    cider_scorer = Cider()
    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):
        # Map the Subset index back to the parent dataset's annotation index.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        # Images are read from the fixed "train2017" directory.
        image_path = os.path.join("train2017", ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        # One prediction per image vs. all reference captions.
        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
63
+
64
+
65
+ # =========================
66
+ # MAIN
67
+ # =========================
68
def main():
    """Data-curriculum experiment: phase-2 BLIP fine-tune with caption modes.

    Same recipe as train_phase2 (last two ViT layers unfrozen, CIDEr-driven
    checkpointing, early stopping) but loads COCODatasetAdvanced with a
    configurable caption MODE.

    Raises:
        RuntimeError: if the Apple-Silicon MPS backend is unavailable.
    """

    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # 🔥 Unfreeze LAST 2 vision layers only
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers.10" in name or "encoder.layers.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Trade compute for memory; use_cache must be off with checkpointing.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    MODE = "long"  # change to "short" or "long"

    full_dataset = COCODatasetAdvanced(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        mode=MODE
    )

    # 90/10 random train/validation split.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # Optimize only the parameters left trainable above.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = 0
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):

        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): fp16 autocast without a GradScaler — confirm
            # gradient underflow is not a concern on MPS.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        # Model selection is driven by CIDEr, not validation loss.
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        if counter >= patience:
            print("Early stopping triggered.")
            break

        scheduler.step()

    print("Phase 2 training complete.")


if __name__ == "__main__":
    main()
train_git.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import DataLoader, random_split
4
+ from transformers import GitProcessor, GitForCausalLM
5
+ from torch.optim import AdamW
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+ from dataset_git import COCODatasetGIT
8
+ from tqdm import tqdm
9
+ from pycocoevalcap.cider.cider import Cider
10
+ from PIL import Image
11
+
12
+
13
def generate_caption(model, processor, image, device):
    """Produce one beam-searched caption (max 20 tokens) for *image*."""
    encoded = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        ids = model.generate(**encoded, num_beams=5, max_length=20)

    # batch_decode returns one string per sequence; keep the first.
    decoded = processor.batch_decode(ids, skip_special_tokens=True)
    return decoded[0]
25
+
26
+
27
def evaluate_cider(model, processor, val_dataset, device, max_samples=200):
    """Compute corpus-level CIDEr over up to *max_samples* validation items.

    *val_dataset* is assumed to be a ``torch.utils.data.Subset`` whose parent
    dataset exposes ``.annotations`` entries with "image" and "captions" keys
    — TODO confirm against COCODatasetGIT.  The model is put back into train
    mode before returning.
    """

    model.eval()
    cider_scorer = Cider()

    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):

        # Map the Subset index back to the parent dataset's annotation index.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        # Images are read from the fixed "train2017" directory.
        image_path = os.path.join("train2017", ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        # One prediction per image vs. all reference captions.
        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    print(f"CIDEr Score: {score:.4f}")
    model.train()
    return score
53
+
54
+
55
def main():
    """Fine-tune microsoft/git-base for captioning on a 20k COCO subset.

    Trains for up to 20 epochs with gradient clipping, evaluates CIDEr after
    every epoch, and saves the model/processor whenever CIDEr improves.
    Falls back to CPU when MPS is unavailable.
    """

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    EPOCHS = 20
    BATCH_SIZE = 4
    LR = 5e-5
    SAVE_DIR = "saved_git_model"

    os.makedirs(SAVE_DIR, exist_ok=True)

    processor = GitProcessor.from_pretrained("microsoft/git-base")
    model = GitForCausalLM.from_pretrained("microsoft/git-base")

    model.to(device)

    dataset = COCODatasetGIT(
        "annotations/subset_20k.jsonl",
        "train2017",
        processor,
        mode="mixed"
    )

    # 90/10 random train/validation split; only the val split feeds CIDEr.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=LR)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_cider = 0

    for epoch in range(EPOCHS):

        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):

            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss

            loss.backward()
            # Clip gradients to stabilize training before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        print(f"Train Loss: {total_loss / len(train_loader):.4f}")

        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # Best-model selection is CIDEr-driven; there is no early stopping.
        if cider_score > best_cider:
            best_cider = cider_score
            model.save_pretrained(SAVE_DIR)
            processor.save_pretrained(SAVE_DIR)
            print("Best GIT model saved.")

        scheduler.step()

    print("GIT Training complete.")


if __name__ == "__main__":
    main()
train_phase2.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import DataLoader, random_split
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ from torch.optim import AdamW
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+ from dataset_advanced import COCODataset
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ from pycocoevalcap.cider.cider import Cider
11
+
12
+
13
+ # =========================
14
+ # GENERATE CAPTION
15
+ # =========================
16
def generate_caption(model, processor, image, device):
    """Return a single beam-searched caption (max 30 tokens) for *image*.

    Preprocesses via *processor*, moves tensors to *device*, and decodes the
    top beam with special tokens stripped.
    """
    encoded = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        ids = model.generate(**encoded, max_length=30, num_beams=5)

    # Decode only the best beam.
    return processor.decode(ids[0], skip_special_tokens=True)
31
+
32
+ # =========================
33
+ # CIDEr EVALUATION
34
+ # =========================
35
def evaluate_cider(model, processor, val_dataset, device, max_samples=200):
    """Compute corpus-level CIDEr over up to *max_samples* validation items.

    Args:
        model: the caption model; switched to eval mode during scoring and
            restored to train mode before returning.
        processor: processor used by ``generate_caption``.
        val_dataset: assumed to be a ``torch.utils.data.Subset`` whose parent
            dataset exposes ``.annotations`` entries with "image" and
            "captions" keys — TODO confirm against the dataset class.
        device: torch device used for generation.
        max_samples: cap on the number of images evaluated.

    Returns:
        float: the corpus-level CIDEr score.
    """
    model.eval()

    cider_scorer = Cider()
    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):
        # Map the Subset index back to the parent dataset's annotation index.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        # Images are read from the fixed "train2017" directory.
        image_path = os.path.join("train2017", ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        # One prediction per image vs. all reference captions.
        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    # BUG FIX: the original f-string used the invalid format spec
    # "{score:.4 f}" (stray space), which raises ValueError at runtime.
    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
60
+
61
+
62
+ # =========================
63
+ # MAIN
64
+ # =========================
65
def main():
    """Phase-2 BLIP fine-tuning (root-level variant using COCODataset).

    Fine-tunes Salesforce/blip-image-captioning-base on a 10k COCO subset with
    only the last two ViT encoder layers trainable, tracks validation loss and
    CIDEr each epoch, keeps the best-CIDEr checkpoint, and early-stops after
    `patience` epochs without CIDEr improvement.

    Raises:
        RuntimeError: if the Apple-Silicon MPS backend is unavailable.
    """

    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # 🔥 Unfreeze LAST 2 vision layers only
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers.10" in name or "encoder.layers.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Trade compute for memory; use_cache must be off with checkpointing.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor
    )

    # 90/10 random train/validation split.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # Optimize only the parameters left trainable above.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = 0
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):

        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): fp16 autocast without a GradScaler — confirm
            # gradient underflow is not a concern on MPS.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        # Model selection is driven by CIDEr, not validation loss.
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        if counter >= patience:
            print("Early stopping triggered.")
            break

        scheduler.step()

    print("Phase 2 training complete.")


if __name__ == "__main__":
    main()
train_vit_gpt2.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import DataLoader, random_split
4
+ from transformers import (
5
+ VisionEncoderDecoderModel,
6
+ ViTImageProcessor,
7
+ AutoTokenizer,
8
+ GPT2Config,
9
+ GPT2LMHeadModel,
10
+ ViTModel
11
+ )
12
+ from torch.optim import AdamW
13
+ from torch.optim.lr_scheduler import CosineAnnealingLR
14
+ from dataset_vit_gpt2 import COCODatasetViTGPT2
15
+ from tqdm import tqdm
16
+ from pycocoevalcap.cider.cider import Cider
17
+ from PIL import Image
18
+
19
+
20
+ # ==========================================
21
+ # GENERATE CAPTION
22
+ # ==========================================
23
def generate_caption(model, processor, tokenizer, image, device):
    """Return one beam-searched caption (max 20 tokens) for *image*.

    GPT-2 has no pad token, so the EOS id is reused for both padding and
    end-of-sequence during generation.
    """
    pix = processor(images=image, return_tensors="pt").pixel_values.to(device)

    eos = tokenizer.eos_token_id
    with torch.no_grad():
        ids = model.generate(
            pixel_values=pix,
            num_beams=5,
            max_length=20,
            length_penalty=1.0,
            pad_token_id=eos,
            eos_token_id=eos,
        )

    return tokenizer.decode(ids[0], skip_special_tokens=True)
38
+
39
+
40
+ # ==========================================
41
+ # CIDEr EVALUATION
42
+ # ==========================================
43
def evaluate_cider(model, processor, tokenizer, val_dataset, device, max_samples=200):
    """Compute corpus-level CIDEr over up to *max_samples* validation items.

    *val_dataset* is assumed to be a ``torch.utils.data.Subset`` whose parent
    dataset exposes ``.annotations`` entries with "image" and "captions" keys
    — TODO confirm against COCODatasetViTGPT2.  The model is put back into
    train mode before returning.
    """

    model.eval()
    cider_scorer = Cider()

    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):

        # Map the Subset index back to the parent dataset's annotation index.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        # Images are read from the fixed "train2017" directory.
        image_path = os.path.join("train2017", ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, tokenizer, image, device)

        # One prediction per image vs. all reference captions.
        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
70
+
71
+
72
+ # ==========================================
73
+ # MAIN
74
+ # ==========================================
75
def main():
    """Train a ViT-encoder / GPT-2-decoder captioner on a 10k COCO subset.

    Builds a VisionEncoderDecoderModel from google/vit-base-patch16-224-in21k
    and gpt2 (with cross-attention enabled), trains with gradient clipping,
    evaluates CIDEr each epoch, and saves the best-CIDEr model, tokenizer and
    image processor to ``saved_vit_gpt2``.
    """

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5
    SAVE_DIR = "saved_vit_gpt2"

    os.makedirs(SAVE_DIR, exist_ok=True)

    # ------------------------------------------
    # Build Encoder + Decoder
    # ------------------------------------------

    encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    # GPT-2 must be reconfigured as a decoder with cross-attention so it can
    # attend to the ViT image features.
    decoder_config = GPT2Config.from_pretrained("gpt2")
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    decoder = GPT2LMHeadModel.from_pretrained("gpt2", config=decoder_config)

    model = VisionEncoderDecoderModel(
        encoder=encoder,
        decoder=decoder
    )

    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # GPT-2 has no pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    model.config.pad_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    model.to(device)

    # ------------------------------------------
    # DATASET
    # ------------------------------------------

    dataset = COCODatasetViTGPT2(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        tokenizer,
        mode="short"
    )

    # 90/10 random train/validation split.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = AdamW(model.parameters(), lr=LR)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_cider = 0

    # ==========================================
    # TRAIN LOOP
    # ==========================================
    for epoch in range(EPOCHS):

        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):

            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            loss.backward()

            # Clip gradients to stabilize training before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Train Loss: {avg_loss:.4f}")

        # ------------------------------------------
        # CIDEr Evaluation
        # ------------------------------------------
        cider_score = evaluate_cider(
            model,
            processor,
            tokenizer,
            val_dataset,
            device
        )

        # Save best model
        # Model selection is CIDEr-driven; no early stopping here.
        if cider_score > best_cider:
            best_cider = cider_score
            model.save_pretrained(SAVE_DIR)
            tokenizer.save_pretrained(SAVE_DIR)
            processor.save_pretrained(SAVE_DIR)
            print("Best model saved.")

        scheduler.step()

    print("ViT-GPT2 Training complete.")


if __name__ == "__main__":
    main()
uploadtohf.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoTokenizer,
3
+ BlipForConditionalGeneration,
4
+ BlipProcessor,
5
+ GitForCausalLM,
6
+ GitProcessor,
7
+ VisionEncoderDecoderModel,
8
+ ViTImageProcessor,
9
+ )
10
+
11
+
12
def push_blip(
    local_dir: str = "saved_model_phase2",
    repo_id: str = "pchandragrid/blip-caption-model",
) -> None:
    """Upload the fine-tuned BLIP model and processor to the Hugging Face Hub.

    Loads both artifacts from *local_dir* and pushes them to *repo_id*.
    NOTE(review): presumably requires an authenticated HF session (token or
    `huggingface-cli login`) — confirm before running.
    """
    model = BlipForConditionalGeneration.from_pretrained(local_dir)
    processor = BlipProcessor.from_pretrained(local_dir)
    model.push_to_hub(repo_id)
    processor.push_to_hub(repo_id)
20
+
21
+
22
def push_vit_gpt2(
    local_dir: str = "saved_vit_gpt2",
    repo_id: str = "pchandragrid/vit-gpt2-caption-model",
) -> None:
    """Upload the ViT-GPT2 captioner (model, image processor, tokenizer).

    Loads all three artifacts from *local_dir* and pushes them to *repo_id*.
    NOTE(review): presumably requires an authenticated HF session — confirm.
    """
    model = VisionEncoderDecoderModel.from_pretrained(local_dir)
    image_processor = ViTImageProcessor.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    model.push_to_hub(repo_id)
    image_processor.push_to_hub(repo_id)
    tokenizer.push_to_hub(repo_id)
32
+
33
+
34
def push_git(
    local_dir: str = "saved_git_model",
    repo_id: str = "pchandragrid/git-caption-model",
) -> None:
    """Upload the fine-tuned GIT model and processor to the Hugging Face Hub.

    Loads both artifacts from *local_dir* and pushes them to *repo_id*.
    NOTE(review): presumably requires an authenticated HF session — confirm.
    """
    model = GitForCausalLM.from_pretrained(local_dir)
    processor = GitProcessor.from_pretrained(local_dir)
    model.push_to_hub(repo_id)
    processor.push_to_hub(repo_id)


if __name__ == "__main__":
    # Uploads all three fine-tuned captioning models in sequence.
    push_blip()
    push_vit_gpt2()
    push_git()
    print("Uploaded: BLIP, ViT-GPT2, and GIT models.")