Ryanfafa commited on
Commit
19ea5c5
·
verified ·
1 Parent(s): 346fd4f

Upload 7 files

Browse files
image_captioning/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image captioning package: EfficientNetB0 encoder + GPT-2 decoder.
3
+
4
+ This package exposes the main components:
5
+ - ImageCaptioningModel (in model.py)
6
+ - dataset/dataloader utilities (in dataset.py)
7
+ - training, evaluation, and inference scripts.
8
+ """
9
+
10
+ from .model import ImageCaptioningModel # noqa: F401
11
+
image_captioning/config.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
@dataclass
class PathsConfig:
    """Dataset and checkpoint path configuration.

    Tailored to the visually impaired dataset layout:
      - Images: <data_root>/visual_dataset/*.jpg
      - Text:   <data_root>/visual_text/visual.token.txt
                <data_root>/visual_text/visual.trainImages.txt
                <data_root>/visual_text/visual.testImages.txt
    """

    data_root: str = "/Users/ryan/Downloads/visuallyimpair"
    images_dir_name: str = "visual_dataset"
    text_dir_name: str = "visual_text"

    def _join(self, *parts: str) -> str:
        # Thin wrapper so the join strategy lives in exactly one place.
        return os.path.join(*parts)

    @property
    def images_dir(self) -> str:
        """Directory containing the .jpg image files."""
        return self._join(self.data_root, self.images_dir_name)

    @property
    def text_dir(self) -> str:
        """Directory containing the caption and split-list text files."""
        return self._join(self.data_root, self.text_dir_name)

    @property
    def token_file(self) -> str:
        """Path to the tab-separated caption file."""
        return self._join(self.text_dir, "visual.token.txt")

    @property
    def train_list_file(self) -> str:
        """Path to the file listing training image names."""
        return self._join(self.text_dir, "visual.trainImages.txt")

    @property
    def test_list_file(self) -> str:
        """Path to the file listing test image names."""
        return self._join(self.text_dir, "visual.testImages.txt")
48
+
49
+
50
@dataclass
class TrainingConfig:
    """Hyperparameters and training-related configuration."""

    # Optimization
    learning_rate: float = 5e-5
    batch_size: int = 16
    num_epochs: int = 10
    warmup_steps: int = 500
    max_caption_length: int = 50
    gradient_accumulation_steps: int = 1
    num_workers: int = 4
    mixed_precision: bool = True
    patience: int = 3  # early-stopping patience, in epochs
    max_grad_norm: float = 1.0

    # Model-specific: how many visual prefix tokens feed the decoder
    prefix_length: int = 1

    # Logging / checkpoints
    output_dir: str = "checkpoints"
    log_dir: str = "runs"

    # Reproducibility
    seed: int = 42
76
+
77
+
78
def get_device() -> torch.device:
    """Pick the best available compute device and report the choice.

    Returns
    -------
    torch.device
        ``cuda`` when a GPU is visible to PyTorch, otherwise ``cpu``.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print("Using CUDA for training/inference.")
    else:
        print("CUDA not available, falling back to CPU.")
    return torch.device("cuda" if use_cuda else "cpu")
90
+
91
+
92
def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Trade cuDNN autotuning speed for bitwise-reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
104
+
105
+
106
def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if missing; no-op when present."""
    os.makedirs(path, exist_ok=True)
112
+
image_captioning/dataset.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from torch import Tensor
8
+ from torch.utils.data import DataLoader, Dataset, Subset
9
+ from torchvision import transforms
10
+ from transformers import GPT2TokenizerFast
11
+
12
+ from .config import PathsConfig, TrainingConfig
13
+
14
+
15
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
16
+ IMAGENET_STD = [0.229, 0.224, 0.225]
17
+
18
+
19
def train_image_transform() -> transforms.Compose:
    """Training-time preprocessing with mild random augmentation.

    Augmentation strength is deliberately moderate so the semantic content
    of the scene (what the caption describes) is left intact.
    """
    pipeline = [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.05,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
41
+
42
+
43
def eval_image_transform() -> transforms.Compose:
    """Deterministic validation/test preprocessing.

    Resize, center-crop to 224x224, tensorize, ImageNet-normalize — no
    randomness, so repeated evaluation passes see identical inputs.
    """
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
57
+
58
+
59
class ImageCaptionDataset(Dataset):
    """Dataset for the visually impaired image-captioning data.

    Expected layout:
      - Images: <data_root>/visual_dataset/*.jpg
      - Text:
          - visual.token.txt       (image#idx<TAB>caption)
          - visual.trainImages.txt (one image filename per line)
          - visual.testImages.txt  (one image filename per line)

    Each item is a dict with the preprocessed image tensor, padded token
    IDs, an attention mask, loss labels (-100 on padding), the raw caption
    string, and the image filename.
    """

    def __init__(
        self,
        paths_cfg: PathsConfig,
        tokenizer: GPT2TokenizerFast,
        split: str = "train",
        training_cfg: Optional[TrainingConfig] = None,
        transform: Optional[transforms.Compose] = None,
        random_caption: bool = True,
    ) -> None:
        """Load caption and split files; raises if files are missing.

        Parameters
        ----------
        paths_cfg:
            Paths configuration pointing at the dataset layout above.
        tokenizer:
            GPT-2 tokenizer (must have a pad token configured).
        split:
            One of {'train', 'val', 'test'}.
        training_cfg:
            Optional training configuration (max_caption_length is used).
        transform:
            Optional image transform; defaults to the deterministic eval
            transform so the class is usable stand-alone.
        random_caption:
            When True, sample a caption per access; otherwise use the first.
        """
        super().__init__()

        if split not in {"train", "val", "test"}:
            raise ValueError("split must be one of {'train', 'val', 'test'}")

        self.paths_cfg = paths_cfg
        self.tokenizer = tokenizer
        self.training_cfg = training_cfg or TrainingConfig()
        # create_dataloader() normally supplies train/eval-specific
        # transforms; the eval transform is a safe deterministic default.
        self.transform = transform or eval_image_transform()
        self.random_caption = random_caption
        self.max_length: int = int(self.training_cfg.max_caption_length)

        self.captions_by_image: Dict[str, List[str]] = self._load_captions(
            self.paths_cfg.token_file
        )

        # Only a single test list exists in this dataset; reuse it for the
        # 'val' split as well.
        if split == "train":
            list_file = self.paths_cfg.train_list_file
        else:
            list_file = self.paths_cfg.test_list_file
        self.image_ids: List[str] = self._load_image_ids(list_file, split)

        print(f"Loaded {len(self.image_ids)} {split} images with captions.")

    @staticmethod
    def _load_captions(token_path: str) -> Dict[str, List[str]]:
        """Parse visual.token.txt into {image filename: [captions, ...]}."""
        if not os.path.exists(token_path):
            raise FileNotFoundError(f"Caption file not found: {token_path}")

        captions_by_image: Dict[str, List[str]] = {}
        with open(token_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    key, caption = line.split("\t", 1)
                except ValueError as exc:
                    raise ValueError(f"Malformed line in {token_path}: {line}") from exc
                # Keys look like 'image.jpg#0'; drop the caption index.
                img_name = key.split("#")[0]
                captions_by_image.setdefault(img_name, []).append(caption.strip())
        return captions_by_image

    def _load_image_ids(self, list_file: str, split: str) -> List[str]:
        """Read a split's image list, keeping only images that have captions."""
        if not os.path.exists(list_file):
            raise FileNotFoundError(
                f"Image list file for split '{split}' not found: {list_file}"
            )

        image_ids: List[str] = []
        with open(list_file, "r", encoding="utf-8") as f:
            for line in f:
                img_name = line.strip()
                # Skip blanks and uncaptioned images to avoid runtime issues.
                if img_name and img_name in self.captions_by_image:
                    image_ids.append(img_name)

        if not image_ids:
            raise RuntimeError(f"No images with captions found for split '{split}'.")
        return image_ids

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
        """Return one preprocessed (image, caption) training example."""
        img_name = self.image_ids[idx]
        img_path = os.path.join(self.paths_cfg.images_dir, img_name)
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        image = Image.open(img_path).convert("RGB")
        image_tensor = self.transform(image)

        caption_list = self.captions_by_image[img_name]
        if not caption_list:
            raise RuntimeError(f"No captions available for image {img_name}")

        # During training, sample among up to three captions per image for
        # variety; for evaluation always take the first for determinism.
        # Only leading/trailing whitespace is stripped so the raw text is
        # preserved before tokenization.
        if self.random_caption:
            caption = random.choice(caption_list[:3])
        else:
            caption = caption_list[0]
        caption = caption.strip()

        # Tokenize without extra special tokens to keep a direct mapping
        # between the raw caption string and the token sequence.
        token_ids: List[int] = self.tokenizer.encode(
            caption,
            add_special_tokens=False,
        )

        # Wrap the caption with explicit BOS/EOS markers. BUGFIX: the
        # previous `bos_token_id or eos_token_id` would wrongly fall back
        # whenever a tokenizer's BOS id is the falsy value 0; compare
        # against None instead.
        eos_token_id = self.tokenizer.eos_token_id
        bos_token_id = (
            self.tokenizer.bos_token_id
            if self.tokenizer.bos_token_id is not None
            else eos_token_id
        )
        seq_ids: List[int] = [bos_token_id] + token_ids + [eos_token_id]

        # Truncate to max_length. BUGFIX: keep an EOS in the final position
        # so truncated captions still carry an end-of-sequence marker;
        # previously truncation could silently drop EOS entirely.
        if len(seq_ids) > self.max_length:
            seq_ids = seq_ids[: self.max_length]
            seq_ids[-1] = eos_token_id

        # Pad to max_length with pad_token_id and build the attention mask.
        pad_id = self.tokenizer.pad_token_id
        input_ids = torch.full((self.max_length,), pad_id, dtype=torch.long)
        attention_mask = torch.zeros(self.max_length, dtype=torch.long)

        seq_len = len(seq_ids)
        input_ids[:seq_len] = torch.tensor(seq_ids, dtype=torch.long)
        attention_mask[:seq_len] = 1

        # Loss labels mirror input_ids; padding positions are set to -100
        # so cross-entropy ignores them.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "image": image_tensor,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "caption": caption,
            "image_id": img_name,
        }
220
+
221
+
222
def create_tokenizer() -> GPT2TokenizerFast:
    """Load the GPT-2 fast tokenizer, guaranteeing a pad token exists.

    GPT-2 ships without a pad token, so EOS is reused for padding when
    none is configured.
    """
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
231
+
232
+
233
+ def _infer_category_from_filename(filename: str) -> str:
234
+ """
235
+ Infer a coarse category label from an image filename.
236
+
237
+ Heuristic:
238
+ - Strip directory and extension.
239
+ - Remove trailing digits to group files like 'bench1.jpg', 'bench25.jpg'
240
+ into the same category 'bench'.
241
+ """
242
+
243
+ base = os.path.basename(filename)
244
+ stem, _ext = os.path.splitext(base)
245
+
246
+ # Remove trailing digits
247
+ i = len(stem)
248
+ while i > 0 and stem[i - 1].isdigit():
249
+ i -= 1
250
+ category = stem[:i] or stem
251
+
252
+ return category
253
+
254
+
255
def _balanced_train_val_indices(
    dataset: ImageCaptionDataset,
    val_ratio: float = 0.2,
) -> Tuple[List[int], List[int]]:
    """Deterministically split dataset indices into train/validation sets.

    The validation set targets roughly ``val_ratio`` of the total size and
    draws the same number of images from every filename-derived category,
    capped by the smallest category so counts stay balanced.
    """
    num_items = len(dataset.image_ids)
    if num_items == 0:
        raise RuntimeError("Cannot create train/val split from an empty dataset.")

    # Bucket indices by inferred category; sort within buckets so the
    # split is reproducible run to run.
    buckets: Dict[str, List[int]] = {}
    for idx, img_name in enumerate(dataset.image_ids):
        buckets.setdefault(_infer_category_from_filename(img_name), []).append(idx)
    for bucket in buckets.values():
        bucket.sort()

    categories = sorted(buckets)
    num_categories = len(categories)

    # Even per-category validation quota: an equal share of the desired
    # total, never exceeding what the smallest category can provide.
    target_val_size = max(1, int(round(val_ratio * num_items)))
    smallest_bucket = min(len(buckets[cat]) for cat in categories)
    per_category = min(
        smallest_bucket,
        max(1, int(round(target_val_size / max(1, num_categories)))),
    )

    val_indices: List[int] = []
    train_indices: List[int] = []
    for cat in categories:
        bucket = buckets[cat]
        val_indices.extend(bucket[:per_category])
        train_indices.extend(bucket[per_category:])

    return train_indices, val_indices
307
+
308
+
309
def create_dataloader(
    paths_cfg: PathsConfig,
    training_cfg: TrainingConfig,
    split: str,
    tokenizer: Optional[GPT2TokenizerFast] = None,
    shuffle: Optional[bool] = None,
) -> Tuple[DataLoader, GPT2TokenizerFast]:
    """Factory function to create a DataLoader for a given split.

    Parameters
    ----------
    paths_cfg:
        Paths configuration.
    training_cfg:
        Training configuration containing batch size, max caption length, etc.
    split:
        One of {'train', 'val', 'test'}.
    tokenizer:
        Optional pre-initialized GPT-2 tokenizer. If None, a new one is created.
    shuffle:
        Optional flag to override shuffle behavior. If None, shuffle is True
        for the 'train' split and False otherwise.

    Notes
    -----
    Train and val are a balanced 80/20 category split of the training list;
    test uses the dedicated test list. BUGFIX: the 'val' split previously
    reused the augmented training dataset (random crops/flips/jitter and
    random caption sampling), making validation metrics noisy and
    non-reproducible. Validation now uses the deterministic eval transform
    and a fixed caption per image.
    """
    if tokenizer is None:
        tokenizer = create_tokenizer()

    if shuffle is None:
        shuffle = split == "train"

    if split == "test":
        dataset = ImageCaptionDataset(
            paths_cfg=paths_cfg,
            tokenizer=tokenizer,
            split="test",
            training_cfg=training_cfg,
            transform=eval_image_transform(),
            random_caption=False,
        )
    elif split in {"train", "val"}:
        # The full (augmented) training dataset defines the index split.
        full_train_dataset = ImageCaptionDataset(
            paths_cfg=paths_cfg,
            tokenizer=tokenizer,
            split="train",
            training_cfg=training_cfg,
            transform=train_image_transform(),
            random_caption=True,  # always randomize captions during training
        )
        train_indices, val_indices = _balanced_train_val_indices(
            full_train_dataset,
            val_ratio=0.2,
        )

        if split == "train":
            dataset = Subset(full_train_dataset, train_indices)
        else:
            # Second dataset over the same image list with deterministic
            # preprocessing. Both datasets read the same list/caption files,
            # so image_ids ordering is identical and the indices computed
            # above remain valid.
            eval_dataset = ImageCaptionDataset(
                paths_cfg=paths_cfg,
                tokenizer=tokenizer,
                split="train",
                training_cfg=training_cfg,
                transform=eval_image_transform(),
                random_caption=False,
            )
            dataset = Subset(eval_dataset, val_indices)
    else:
        raise ValueError("split must be one of {'train', 'val', 'test'}")

    dataloader = DataLoader(
        dataset,
        batch_size=training_cfg.batch_size,
        shuffle=shuffle,
        num_workers=training_cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    return dataloader, tokenizer
385
+
image_captioning/evaluate.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from typing import Dict, List, Tuple
5
+
6
+ import torch
7
+ from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
8
+ from nltk.translate.meteor_score import single_meteor_score
9
+ from rouge_score import rouge_scorer
10
+
11
+ from .config import PathsConfig, TrainingConfig, get_device, set_seed
12
+ from .dataset import create_dataloader, create_tokenizer
13
+ from .model import ImageCaptioningModel
14
+
15
+
16
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate image captioning model on test set.")
    add = parser.add_argument
    add("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset.")
    add("--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt).")
    add("--batch_size", type=int, default=16, help="Batch size for evaluation.")
    add("--max_length", type=int, default=50, help="Maximum caption length during generation.")
    add("--num_beams", type=int, default=3, help="Number of beams for beam search.")
    add("--seed", type=int, default=42, help="Random seed.")
    add("--output_samples", type=str, default="evaluation_samples.jsonl", help="File to save sample predictions.")
    return parser.parse_args()
30
+
31
+
32
def _unigram_f1(ref_tokens: List[str], hyp_tokens: List[str]) -> float:
    """Unigram F1 fallback used when METEOR's WordNet data is unavailable."""
    ref_set = set(ref_tokens)
    hyp_set = set(hyp_tokens)
    if not ref_set or not hyp_set:
        return 0.0
    overlap = len(ref_set & hyp_set)
    precision = overlap / len(hyp_set)
    recall = overlap / len(ref_set)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def compute_metrics(
    references: List[List[str]],
    hypotheses: List[str],
) -> Dict[str, float]:
    """Compute BLEU (1-4), METEOR, and ROUGE-L metrics.

    Parameters
    ----------
    references:
        One list of reference caption strings per sample.
    hypotheses:
        One predicted caption string per sample.

    Raises
    ------
    ValueError
        If the inputs are empty or their lengths do not match.

    Notes
    -----
    BUGFIX: NLTK's ``corpus_bleu`` expects token lists
    (``list[list[list[str]]]`` references, ``list[list[str]]`` hypotheses).
    Raw strings were previously passed, so NLTK iterated them character by
    character and BLEU was effectively computed over characters instead of
    words. Captions are now whitespace-tokenized before scoring, matching
    the tokenization already used for METEOR.
    """
    if not references or not hypotheses:
        raise ValueError("References and hypotheses must be non-empty.")
    if len(references) != len(hypotheses):
        raise ValueError("Number of references and hypotheses must match.")

    # Whitespace tokenization for word-level BLEU.
    tokenized_refs = [[ref.split() for ref in ref_list] for ref_list in references]
    tokenized_hyps = [hyp.split() for hyp in hypotheses]

    smoothie = SmoothingFunction().method4

    # BLEU-1 through BLEU-4 differ only in their n-gram weights.
    bleu_weight_sets = {
        "BLEU-1": (1.0, 0.0, 0.0, 0.0),
        "BLEU-2": (0.5, 0.5, 0.0, 0.0),
        "BLEU-3": (1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0),
        "BLEU-4": (0.25, 0.25, 0.25, 0.25),
    }
    metrics: Dict[str, float] = {
        name: corpus_bleu(
            tokenized_refs,
            tokenized_hyps,
            weights=weights,
            smoothing_function=smoothie,
        )
        for name, weights in bleu_weight_sets.items()
    }

    # METEOR over the first reference only; fall back to a simple unigram
    # F1 when NLTK's WordNet data is missing (LookupError).
    meteor_scores: List[float] = []
    for ref_list, hyp in zip(references, hypotheses):
        ref_tokens = ref_list[0].split()
        hyp_tokens = hyp.split()
        try:
            meteor_scores.append(single_meteor_score(ref_tokens, hyp_tokens))
        except LookupError:
            meteor_scores.append(_unigram_f1(ref_tokens, hyp_tokens))
    metrics["METEOR"] = sum(meteor_scores) / max(1, len(meteor_scores))

    # ROUGE-L F-measure against the first reference only.
    rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    rouge_l_scores = [
        rouge.score(ref_list[0], hyp)["rougeL"].fmeasure
        for ref_list, hyp in zip(references, hypotheses)
    ]
    metrics["ROUGE-L"] = sum(rouge_l_scores) / max(1, len(rouge_l_scores))

    return metrics
113
+
114
+
115
def run_evaluation(args: argparse.Namespace) -> None:
    """Run evaluation on the test set, compute metrics, and save sample predictions."""
    paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = TrainingConfig(
        batch_size=args.batch_size,
        max_caption_length=args.max_length,
        num_epochs=1,
    )

    set_seed(args.seed)
    device = get_device()

    tokenizer = create_tokenizer()
    test_loader, tokenizer = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="test",
        tokenizer=tokenizer,
        shuffle=False,
    )

    # Restore the trained model onto the selected device.
    model = ImageCaptioningModel(training_cfg=training_cfg)
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)
    model.eval()

    references: List[List[str]] = []
    hypotheses: List[str] = []
    saved_samples: List[Dict[str, str]] = []
    sample_budget = 50  # cap on predictions written to disk

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            # The raw caption strings from the dataset serve as references;
            # generate() is invoked one image at a time to respect its
            # constraints.
            for idx, ref_caption in enumerate(batch["caption"]):
                prediction = model.generate(
                    images=images[idx : idx + 1],
                    max_length=args.max_length,
                    num_beams=args.num_beams,
                )[0]

                references.append([ref_caption])
                hypotheses.append(prediction)

                if len(saved_samples) < sample_budget:
                    saved_samples.append(
                        {
                            "image_id": batch["image_id"][idx],
                            "reference": ref_caption,
                            "prediction": prediction,
                        }
                    )

    metrics = compute_metrics(references, hypotheses)

    print("Evaluation metrics:")
    for name, value in metrics.items():
        print(f" {name}: {value:.4f}")

    # Persist sample predictions as JSON lines.
    output_path = args.output_samples
    with open(output_path, "w", encoding="utf-8") as f:
        for sample in saved_samples:
            f.write(json.dumps(sample) + "\n")

    print(f"Saved {len(saved_samples)} sample predictions to {output_path}")
194
+
195
+
196
def main() -> None:
    """CLI entry point: validate the checkpoint path, then evaluate."""
    args = parse_args()
    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
    run_evaluation(args)


if __name__ == "__main__":
    main()
207
+
image_captioning/inference.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import List
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+
9
+ from .config import PathsConfig, TrainingConfig, get_device, set_seed
10
+ from .dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
11
+ from .model import ImageCaptioningModel
12
+
13
+
14
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for inference."""
    parser = argparse.ArgumentParser(description="Run image captioning inference on a single image.")
    add = parser.add_argument
    add("--image", type=str, required=True, help="Path to image file.")
    add("--checkpoint", type=str, default="checkpoints/best_model.pt", help="Path to model checkpoint.")
    add("--max_length", type=int, default=50, help="Maximum caption length.")
    add("--num_beams", type=int, default=3, help="Number of beams for beam search.")
    add("--seed", type=int, default=42, help="Random seed.")
    add("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset (for consistency).")
    return parser.parse_args()
27
+
28
+
29
def build_preprocess_transform() -> transforms.Compose:
    """Deterministic preprocessing matching the training-time eval pipeline.

    Resize, center-crop to 224x224, tensorize, ImageNet-normalize.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
42
+
43
+
44
def load_image(image_path: str) -> torch.Tensor:
    """Load one image file and preprocess it into a single-item batch.

    Returns a tensor of shape (1, 3, 224, 224); raises FileNotFoundError
    when the path does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    rgb_image = Image.open(image_path).convert("RGB")
    preprocess = build_preprocess_transform()
    return preprocess(rgb_image).unsqueeze(0)  # (1, 3, 224, 224)
56
+
57
+
58
def run_inference(args: argparse.Namespace) -> List[str]:
    """Generate and print a caption for the image given on the command line.

    Returns the list of generated caption strings.
    """
    set_seed(args.seed)
    device = get_device()

    # Kept for consistency with the other entry points / future extensions.
    _paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = TrainingConfig(max_caption_length=args.max_length)

    tokenizer = create_tokenizer()

    model = ImageCaptioningModel(training_cfg=training_cfg)

    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)
    model.eval()

    image_tensor = load_image(args.image).to(device)

    captions = model.generate(
        images=image_tensor,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )

    for caption in captions:
        print(f"Caption: {caption}")

    return captions
92
+
93
+
94
def main() -> None:
    """CLI entry point for single-image captioning."""
    run_inference(parse_args())


if __name__ == "__main__":
    main()
101
+
image_captioning/model.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+ from torchvision import models
9
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
10
+
11
+ from .config import TrainingConfig, get_device
12
+
13
+
14
@dataclass
class ImageCaptioningOutput:
    """Container for model outputs.

    Attributes
    ----------
    logits:
        Predicted token logits of shape (batch_size, seq_len, vocab_size);
        seq_len counts text tokens only (visual prefix positions removed).
    loss:
        Optional cross-entropy loss over caption tokens.
    """

    logits: Tensor
    loss: Optional[Tensor] = None
30
+
31
+
32
class EfficientNetB0Encoder(nn.Module):
    """EfficientNet-B0 backbone that yields pooled image features.

    The torchvision classification head is discarded; forward() returns
    only the pooled feature vector (1280-dim for B0).
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        backbone = models.efficientnet_b0(pretrained=pretrained)
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        self.flatten = nn.Flatten()
        # The classifier's in_features equals the pooled feature width.
        self.out_dim: int = backbone.classifier[1].in_features

    def forward(self, images: Tensor) -> Tensor:
        """Map a (batch, 3, 224, 224) image batch to (batch, out_dim) features."""
        pooled = self.avgpool(self.features(images))
        return self.flatten(pooled)
63
+
64
+
65
class ImageCaptioningModel(nn.Module):
    """
    Image captioning model with an EfficientNet-B0 vision encoder and GPT-2 decoder.

    The model projects visual features into a sequence of prefix embeddings that
    are concatenated with GPT-2 token embeddings. GPT-2 then predicts caption tokens.
    """

    def __init__(
        self,
        training_cfg: Optional[TrainingConfig] = None,
        pretrained_encoder: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        training_cfg:
            Training configuration; a default TrainingConfig is built when omitted.
        pretrained_encoder:
            Whether to load ImageNet weights for the EfficientNet-B0 encoder.
        """

        super().__init__()

        self.training_cfg = training_cfg or TrainingConfig()
        self.device: torch.device = get_device()

        # Vision encoder
        self.encoder = EfficientNetB0Encoder(pretrained=pretrained_encoder)

        # Text decoder (GPT-2 small)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        if self.tokenizer.pad_token is None:
            # GPT-2 ships without a pad token; reuse EOS so padding is valid.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        self.gpt2.config.pad_token_id = self.tokenizer.pad_token_id

        # Number of visual prefix tokens prepended to every caption.
        self.prefix_length: int = int(self.training_cfg.prefix_length)
        if self.prefix_length < 1:
            raise ValueError("prefix_length must be >= 1")

        # Project image features to a flat vector that is reshaped into
        # prefix_length embeddings of GPT-2's hidden size.
        self.visual_projection = nn.Linear(
            self.encoder.out_dim,
            self.gpt2.config.n_embd * self.prefix_length,
        )

        # Debug shapes are printed once on the first forward pass only.
        self._printed_debug: bool = False

        self.to(self.device)

    # --------------------------------------------------------------------- #
    # Internal utilities
    # --------------------------------------------------------------------- #
    def encode_images(self, images: Tensor) -> Tensor:
        """
        Encode images and produce visual prefix embeddings.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, H, W).

        Returns
        -------
        Tensor of shape (batch_size, prefix_length, hidden_size).
        """

        assert images.dim() == 4, f"Expected images of shape (B,3,H,W), got {images.shape}"
        img_features = self.encoder(images)  # (B, encoder_out_dim)
        batch_size = img_features.size(0)

        prefix_embeddings = self.visual_projection(img_features)
        # Reshape the flat projection into a sequence of prefix tokens.
        prefix_embeddings = prefix_embeddings.view(
            batch_size,
            self.prefix_length,
            self.gpt2.config.n_embd,
        )
        return prefix_embeddings

    # --------------------------------------------------------------------- #
    # Forward (training)
    # --------------------------------------------------------------------- #
    def forward(
        self,
        images: Tensor,
        captions: Tensor,
        attention_mask: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
    ) -> ImageCaptioningOutput:
        """
        Forward pass for training.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, 224, 224).
        captions:
            Token IDs of shape (batch_size, seq_len).
        attention_mask:
            Optional attention mask of shape (batch_size, seq_len).
        labels:
            Optional target token IDs of shape (batch_size, seq_len).
            If provided, cross-entropy loss is computed, ignoring positions
            with label -100.
        """

        images = images.to(self.device)
        captions = captions.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)

        batch_size, seq_len = captions.shape
        assert images.size(0) == batch_size, "Batch size mismatch between images and captions."

        prefix_embeddings = self.encode_images(images)  # (B, P, H)

        # Embed caption tokens and prepend the visual prefix.
        token_embeddings = self.gpt2.transformer.wte(captions)  # (B, T, H)
        inputs_embeds = torch.cat([prefix_embeddings, token_embeddings], dim=1)  # (B, P+T, H)

        if attention_mask is not None:
            # Prefix positions are always attended to.
            prefix_mask = torch.ones(
                batch_size,
                self.prefix_length,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            extended_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        else:
            extended_attention_mask = None

        if not self._printed_debug:
            print(f"[DEBUG] images shape: {images.shape}")
            print(f"[DEBUG] captions shape: {captions.shape}")
            print(f"[DEBUG] prefix_embeddings: {prefix_embeddings.shape}")
            print(f"[DEBUG] token_embeddings: {token_embeddings.shape}")
            print(f"[DEBUG] inputs_embeds shape: {inputs_embeds.shape}")
            if extended_attention_mask is not None:
                print(f"[DEBUG] attention_mask shape: {extended_attention_mask.shape}")
            self._printed_debug = True

        outputs = self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            use_cache=False,
            return_dict=True,
        )

        # Remove visual prefix positions from the logits so that
        # the returned logits only correspond to text tokens.
        logits = outputs.logits[:, self.prefix_length :, :]  # (B, T, V)

        loss: Optional[Tensor] = None
        if labels is not None:
            if labels.shape != (batch_size, seq_len):
                raise ValueError(
                    f"labels shape {labels.shape} does not match captions shape {(batch_size, seq_len)}"
                )

            # Shift logits and labels for next-token prediction.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return ImageCaptioningOutput(logits=logits, loss=loss)

    # --------------------------------------------------------------------- #
    # Generation (inference)
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def generate(
        self,
        images: Tensor,
        max_length: int = 50,
        num_beams: int = 1,
        temperature: float = 1.0,
        top_k: int = 0,
        eos_token_id: Optional[int] = None,
        length_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """
        Generate captions for a batch of images using a simple beam search.

        Notes
        -----
        - For simplicity and clarity, this implementation currently supports
          batch_size == 1. A ValueError is raised otherwise.
        - ``repetition_penalty`` follows the CTRL formulation: the logit of any
          token already present in a beam is divided by the penalty when
          positive and multiplied when negative. The default of 1.0 is a
          no-op. (Previously the parameter was accepted but silently ignored.)
        """

        self.eval()

        images = images.to(self.device)
        batch_size = images.size(0)
        if batch_size != 1:
            raise ValueError(f"generate currently supports batch_size == 1, got {batch_size}")

        # Use explicit `is None` checks: token id 0 is a valid id and a bare
        # `or` would silently replace it with the tokenizer default.
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        bos_token_id = self.tokenizer.bos_token_id
        if bos_token_id is None:
            bos_token_id = self.tokenizer.eos_token_id

        prefix_embeddings = self.encode_images(images)  # (1, P, H)

        # Each beam is (token_ids, log_prob)
        beams: List[Tuple[List[int], float]] = [([], 0.0)]

        def _length_normalized_score(tokens: List[int], score: float) -> float:
            # With length_penalty == 0 the raw log-probability is used.
            if length_penalty is None or length_penalty == 0.0:
                return score
            length = max(1, len(tokens))
            return score / (length ** length_penalty)

        for _ in range(max_length):
            all_candidates: List[Tuple[List[int], float]] = []
            for seq, score in beams:
                if seq and seq[-1] == eos_token_id:
                    # If already finished, keep as-is
                    all_candidates.append((seq, score))
                    continue

                # Build a 2D tensor of token IDs with shape (1, L); an empty
                # beam is seeded with BOS.
                if seq:
                    input_ids = torch.tensor(
                        [seq],
                        device=self.device,
                        dtype=torch.long,
                    )  # (1, L)
                else:
                    input_ids = torch.tensor(
                        [[bos_token_id]],
                        device=self.device,
                        dtype=torch.long,
                    )  # (1, 1)

                token_embeddings = self.gpt2.transformer.wte(input_ids)  # (1, L, H)
                inputs_embeds = torch.cat([prefix_embeddings, token_embeddings], dim=1)

                attention_mask = torch.ones(
                    inputs_embeds.size()[:-1],
                    dtype=torch.long,
                    device=self.device,
                )

                outputs = self.gpt2(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True,
                )

                # Temperature is clamped away from zero to avoid division by 0.
                logits = outputs.logits[:, -1, :] / max(temperature, 1e-5)

                # Apply CTRL-style repetition penalty to previously generated
                # tokens before computing probabilities.
                if repetition_penalty != 1.0 and seq:
                    for prev_token in set(seq):
                        if logits[0, prev_token] > 0:
                            logits[0, prev_token] /= repetition_penalty
                        else:
                            logits[0, prev_token] *= repetition_penalty

                if top_k > 0:
                    topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
                    log_probs = torch.log_softmax(topk_logits, dim=-1)
                    for i in range(top_k):
                        token_id = int(topk_indices[0, i])
                        candidate = (seq + [token_id], score + float(log_probs[0, i]))
                        all_candidates.append(candidate)
                else:
                    log_probs = torch.log_softmax(logits, dim=-1)
                    topk_log_probs, topk_indices = torch.topk(log_probs, num_beams, dim=-1)
                    for i in range(num_beams):
                        token_id = int(topk_indices[0, i])
                        candidate = (seq + [token_id], score + float(topk_log_probs[0, i]))
                        all_candidates.append(candidate)

            # Select best beams. With num_beams=1 and length_penalty=0 this
            # reduces to simple greedy decoding, which is fully deterministic.
            beams = sorted(
                all_candidates,
                key=lambda x: _length_normalized_score(x[0], x[1]),
                reverse=True,
            )[:num_beams]

            # If all beams ended with EOS, stop early
            if all(seq and seq[-1] == eos_token_id for seq, _ in beams):
                break

        best_seq, _ = max(
            beams,
            key=lambda x: _length_normalized_score(x[0], x[1]),
        )

        # Truncate at EOS if present
        if eos_token_id in best_seq:
            best_seq = best_seq[: best_seq.index(eos_token_id)]

        caption = self.tokenizer.decode(best_seq, skip_special_tokens=True)
        # Normalize whitespace so the final caption is a single, clean string.
        caption = " ".join(caption.strip().split())
        return [caption]

    # --------------------------------------------------------------------- #
    # Dummy test helper
    # --------------------------------------------------------------------- #
    def test_dummy(self) -> None:
        """
        Run a dummy forward pass to verify the model works end-to-end.

        Asserts that the output logits have shape (2, 20, vocab_size) when
        captions have length 20.
        """

        self.eval()
        vocab_size = int(self.gpt2.config.vocab_size)

        dummy_images = torch.randn(2, 3, 224, 224, device=self.device)
        dummy_captions = torch.randint(0, vocab_size, (2, 20), device=self.device)

        with torch.no_grad(), contextlib.ExitStack() as stack:
            # Autocast only applies on CUDA; CPU runs in full precision.
            if self.device.type == "cuda":
                stack.enter_context(torch.cuda.amp.autocast())

            outputs = self(dummy_images, dummy_captions)

        logits = outputs.logits
        assert logits.shape == (2, 20, vocab_size), (
            f"Output shape mismatch: expected (2, 20, {vocab_size}), "
            f"got {tuple(logits.shape)}"
        )
        print("✓ Model architecture verified successfully!")
382
+
image_captioning/train.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.optim import AdamW
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from tqdm import tqdm
11
+ from transformers import get_cosine_schedule_with_warmup
12
+
13
+ from .config import PathsConfig, TrainingConfig, ensure_dir, get_device, set_seed
14
+ from .dataset import create_dataloader, create_tokenizer
15
+ from .model import ImageCaptioningModel
16
+
17
+
18
def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for training.
    """

    parser = argparse.ArgumentParser(description="Train EfficientNetB0 + GPT-2 image captioning model.")

    # (flag, type, default, help) — one entry per CLI option.
    option_table = [
        ("--data_root", str, "/Users/ryan/Downloads/visuallyimpair", "Root path to dataset."),
        ("--epochs", int, 10, "Number of training epochs."),
        ("--batch_size", int, 16, "Batch size."),
        ("--lr", float, 5e-5, "Learning rate."),
        ("--warmup_steps", int, 500, "Number of warmup steps."),
        ("--max_length", int, 50, "Maximum caption length."),
        ("--grad_accum_steps", int, 1, "Gradient accumulation steps."),
        ("--output_dir", str, "checkpoints", "Directory to save checkpoints."),
        ("--log_dir", str, "runs", "Directory for TensorBoard logs."),
        ("--patience", int, 10, "Early stopping patience based on validation loss."),
        ("--seed", int, 42, "Random seed."),
    ]
    for flag, flag_type, default, help_text in option_table:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)

    return parser.parse_args()
36
+
37
+
38
def create_training_config_from_args(args: argparse.Namespace) -> TrainingConfig:
    """
    Create a TrainingConfig instance using command-line arguments.
    """

    cfg = TrainingConfig()

    # Map each config attribute to its CLI-provided value; the accumulation
    # step count is clamped to at least 1.
    overrides = {
        "learning_rate": args.lr,
        "batch_size": args.batch_size,
        "num_epochs": args.epochs,
        "warmup_steps": args.warmup_steps,
        "max_caption_length": args.max_length,
        "gradient_accumulation_steps": max(1, args.grad_accum_steps),
        "output_dir": args.output_dir,
        "log_dir": args.log_dir,
        "patience": args.patience,
        "seed": args.seed,
    }
    for attr, value in overrides.items():
        setattr(cfg, attr, value)

    return cfg
55
+
56
+
57
def validate_dataloader(
    train_loader,
    device: torch.device,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Fetch a single batch from the DataLoader to validate dataset loading.

    Returns
    -------
    Tuple of (images, input_ids, attention_mask, labels), each moved to
    the target device.
    """

    try:
        first_batch = next(iter(train_loader))
    except StopIteration as exc:
        raise RuntimeError("Training DataLoader is empty. Check your dataset configuration.") from exc

    # Move every expected tensor to the device and report its shape.
    field_names = ("images", "input_ids", "attention_mask", "labels")
    batch_keys = ("image", "input_ids", "attention_mask", "labels")
    tensors = tuple(first_batch[key].to(device) for key in batch_keys)

    for name, tensor in zip(field_names, tensors):
        print(f"[DATA] {name} batch shape: {tensor.shape}")

    return tensors
85
+
86
+
87
def train_one_epoch(
    model: ImageCaptioningModel,
    train_loader,
    optimizer: AdamW,
    scheduler,
    device: torch.device,
    cfg: TrainingConfig,
    epoch: int,
    scaler: torch.cuda.amp.GradScaler,
    writer: SummaryWriter,
) -> float:
    """
    Train the model for a single epoch.

    Returns
    -------
    float
        The average training loss over the epoch (also logged to TensorBoard
        under "Loss/train").
    """

    model.train()
    running_loss = 0.0
    num_steps = 0

    grad_accum_steps = cfg.gradient_accumulation_steps
    # True while gradients have been accumulated but not yet applied.
    pending_update = False

    def _apply_optimizer_step() -> None:
        # Unscale before clipping so the norm is computed on true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

    progress = tqdm(train_loader, desc=f"Epoch {epoch} [train]", unit="batch")
    for step, batch in enumerate(progress):
        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.cuda.amp.autocast(enabled=(device.type == "cuda" and cfg.mixed_precision)):
            outputs = model(
                images=images,
                captions=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            if loss is None:
                raise RuntimeError("Model did not return a loss during training.")

            # Scale down so accumulated gradients average over micro-batches.
            loss = loss / grad_accum_steps

        scaler.scale(loss).backward()
        pending_update = True

        if (step + 1) % grad_accum_steps == 0:
            _apply_optimizer_step()
            pending_update = False

        # Undo the accumulation scaling when reporting the loss.
        running_loss += loss.item() * grad_accum_steps
        num_steps += 1
        avg_loss = running_loss / num_steps
        progress.set_postfix({"loss": f"{avg_loss:.4f}"})

    # Bug fix: when the number of batches is not divisible by grad_accum_steps,
    # the final partial accumulation used to be dropped — gradients were
    # computed but never stepped or zeroed, and would leak into the next
    # epoch's first update. Flush them here.
    if pending_update:
        _apply_optimizer_step()

    epoch_loss = running_loss / max(1, num_steps)
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    return epoch_loss
146
+
147
+
148
def evaluate(
    model: ImageCaptioningModel,
    val_loader,
    device: torch.device,
    cfg: TrainingConfig,
    epoch: int,
    writer: SummaryWriter,
) -> float:
    """
    Evaluate the model on a validation split and return the average loss.

    The mean loss is also logged to TensorBoard under "Loss/val".
    """

    model.eval()
    total_loss = 0.0
    batches_seen = 0

    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Epoch {epoch} [val]", unit="batch")
        for batch in progress:
            # Move the whole batch onto the target device.
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                images=images,
                captions=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            if outputs.loss is None:
                raise RuntimeError("Model did not return a loss during validation.")

            total_loss += outputs.loss.item()
            batches_seen += 1
            progress.set_postfix({"val_loss": f"{total_loss / batches_seen:.4f}"})

    val_loss = total_loss / max(1, batches_seen)
    writer.add_scalar("Loss/val", val_loss, epoch)
    return val_loss
190
+
191
+
192
def main() -> None:
    """
    Entry point: configure, build data/model/optimizer, and run the
    train/validate loop with checkpointing and early stopping.
    """

    args = parse_args()

    # Configuration and setup
    paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = create_training_config_from_args(args)

    ensure_dir(training_cfg.output_dir)
    ensure_dir(training_cfg.log_dir)

    set_seed(training_cfg.seed)
    device = get_device()

    # Data — the train loader may extend the tokenizer, so the returned
    # tokenizer is reused for the validation split.
    tokenizer = create_tokenizer()
    train_loader, tokenizer = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="train",
        tokenizer=tokenizer,
        shuffle=True,
    )
    val_loader, _ = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="val",
        tokenizer=tokenizer,
        shuffle=False,
    )

    # Validate dataset loading by pulling (and printing) one batch up front.
    validate_dataloader(train_loader, device)

    # Model
    model = ImageCaptioningModel(training_cfg=training_cfg)

    optimizer = AdamW(model.parameters(), lr=training_cfg.learning_rate)

    # Total optimizer steps = ceil(batches / accumulation) per epoch, times epochs.
    total_training_steps = math.ceil(
        len(train_loader) / max(1, training_cfg.gradient_accumulation_steps)
    ) * training_cfg.num_epochs

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=training_cfg.warmup_steps,
        num_training_steps=total_training_steps,
    )

    # AMP scaler is a no-op on CPU or when mixed precision is disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda" and training_cfg.mixed_precision))
    writer = SummaryWriter(log_dir=training_cfg.log_dir)

    best_val_loss = float("inf")
    epochs_without_improvement = 0

    try:
        for epoch in range(1, training_cfg.num_epochs + 1):
            train_loss = train_one_epoch(
                model=model,
                train_loader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                cfg=training_cfg,
                epoch=epoch,
                scaler=scaler,
                writer=writer,
            )

            val_loss = evaluate(
                model=model,
                val_loader=val_loader,
                device=device,
                cfg=training_cfg,
                epoch=epoch,
                writer=writer,
            )

            print(f"[EPOCH {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f}")

            # Checkpointing: keep only the best model by validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                best_path = os.path.join(training_cfg.output_dir, "best_model.pt")
                torch.save(model.state_dict(), best_path)
                print(f"[CHECKPOINT] Saved new best model to {best_path}")
            else:
                epochs_without_improvement += 1
                print(
                    f"[EARLY STOP] No improvement for {epochs_without_improvement} "
                    f"epoch(s) (patience={training_cfg.patience})."
                )

            if epochs_without_improvement >= training_cfg.patience:
                print("Early stopping triggered.")
                break
    except Exception as exc:  # noqa: BLE001
        # Surface the failure in the log, then re-raise for a proper traceback.
        print(f"[ERROR] Training failed with error: {exc}")
        raise
    finally:
        # Always flush TensorBoard logs, even on failure or early stop.
        writer.close()


if __name__ == "__main__":
    main()
297
+