Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- image_captioning/__init__.py +11 -0
- image_captioning/config.py +112 -0
- image_captioning/dataset.py +385 -0
- image_captioning/evaluate.py +207 -0
- image_captioning/inference.py +101 -0
- image_captioning/model.py +382 -0
- image_captioning/train.py +297 -0
image_captioning/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Image captioning package: EfficientNetB0 encoder + GPT-2 decoder.

This package exposes the main components:
- ImageCaptioningModel (in model.py)
- dataset/dataloader utilities (in dataset.py)
- training, evaluation, and inference scripts.
"""

from .model import ImageCaptioningModel  # noqa: F401

# Declare the public API explicitly so `from image_captioning import *`
# and static tooling see exactly what this package intends to export.
__all__ = ["ImageCaptioningModel"]
|
image_captioning/config.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class PathsConfig:
    """
    Locations of dataset images, caption files, and split lists.

    Assumed layout (visually impaired dataset):
      - Images: <data_root>/visual_dataset/*.jpg
      - Text:   <data_root>/visual_text/visual.token.txt
                <data_root>/visual_text/visual.trainImages.txt
                <data_root>/visual_text/visual.testImages.txt
    """

    data_root: str = "/Users/ryan/Downloads/visuallyimpair"
    images_dir_name: str = "visual_dataset"
    text_dir_name: str = "visual_text"

    def _join(self, *parts: str) -> str:
        """Join path components (thin wrapper over os.path.join)."""
        return os.path.join(*parts)

    @property
    def images_dir(self) -> str:
        """Directory containing the image files."""
        return os.path.join(self.data_root, self.images_dir_name)

    @property
    def text_dir(self) -> str:
        """Directory containing caption and split-list text files."""
        return os.path.join(self.data_root, self.text_dir_name)

    @property
    def token_file(self) -> str:
        """Caption file: one `image#idx<TAB>caption` entry per line."""
        return os.path.join(self.text_dir, "visual.token.txt")

    @property
    def train_list_file(self) -> str:
        """Split list holding one training image filename per line."""
        return os.path.join(self.text_dir, "visual.trainImages.txt")

    @property
    def test_list_file(self) -> str:
        """Split list holding one test image filename per line."""
        return os.path.join(self.text_dir, "visual.testImages.txt")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class TrainingConfig:
    """
    Hyperparameters and training-related configuration.
    """

    # Optimization settings
    learning_rate: float = 5e-5
    batch_size: int = 16
    num_epochs: int = 10
    warmup_steps: int = 500  # presumably LR-scheduler warmup — confirm in train.py
    max_caption_length: int = 50  # token budget per caption incl. BOS/EOS; dataset.py truncates to this
    gradient_accumulation_steps: int = 1
    num_workers: int = 4  # DataLoader worker processes (used in dataset.create_dataloader)
    mixed_precision: bool = True  # presumably toggles AMP during training — confirm in train.py
    patience: int = 3  # presumably early-stopping patience in epochs — confirm in train.py
    max_grad_norm: float = 1.0  # presumably gradient-clipping threshold — confirm in train.py

    # Model-specific
    prefix_length: int = 1  # number of visual prefix tokens

    # Logging / checkpoints
    output_dir: str = "checkpoints"
    log_dir: str = "runs"

    # Reproducibility
    seed: int = 42
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_device() -> torch.device:
    """
    Pick the best available compute device.

    Returns
    -------
    torch.device
        ``cuda`` when a GPU is visible to PyTorch, otherwise ``cpu``.
        A one-line notice is printed either way.
    """

    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        print("Using CUDA for training/inference.")
    else:
        print("CUDA not available, falling back to CPU.")
    return torch.device("cuda" if cuda_ok else "cpu")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def set_seed(seed: int) -> None:
    """
    Seed every RNG we rely on (Python, NumPy, PyTorch CPU and all GPUs)
    for reproducible runs.

    Also forces cuDNN into deterministic mode; this trades some GPU speed
    for run-to-run reproducibility.
    """

    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op when CUDA is absent
    ):
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def ensure_dir(path: str) -> None:
    """Create *path* (including missing parents); no-op if it already exists."""
    os.makedirs(path, exist_ok=True)
|
| 112 |
+
|
image_captioning/dataset.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset, Subset
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from transformers import GPT2TokenizerFast
|
| 11 |
+
|
| 12 |
+
from .config import PathsConfig, TrainingConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ImageNet per-channel normalization statistics (the standard torchvision
# values). Applied in both transforms below so inputs match what an
# ImageNet-pretrained encoder expects — presumably EfficientNetB0 per the
# package docstring; confirm in model.py.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def train_image_transform() -> transforms.Compose:
    """
    Augmented preprocessing pipeline for training.

    The random crop/flip/color-jitter augmentations are deliberately mild
    so the semantic content of the scene — what the caption describes —
    is preserved while still improving generalization.
    """

    pipeline = [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def eval_image_transform() -> transforms.Compose:
    """
    Deterministic preprocessing for validation/test: resize shortest side
    to 256, center-crop to 224x224, convert to tensor, normalize.
    """

    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ImageCaptionDataset(Dataset):
    """
    Custom Dataset for the visually impaired image captioning data.

    Expected layout:
      - Images: <data_root>/visual_dataset/*.jpg
      - Text:
          - visual.token.txt       (image#idx<TAB>caption)
          - visual.trainImages.txt (one image filename per line)
          - visual.testImages.txt  (one image filename per line)

    Each item yields the preprocessed image tensor plus caption tensors
    (input_ids / attention_mask / labels) padded to a fixed length, along
    with the raw caption string and the image filename.
    """

    def __init__(
        self,
        paths_cfg: PathsConfig,
        tokenizer: GPT2TokenizerFast,
        split: str = "train",
        training_cfg: Optional[TrainingConfig] = None,
        transform: Optional[transforms.Compose] = None,
        random_caption: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        paths_cfg:
            Dataset path configuration.
        tokenizer:
            GPT-2 tokenizer; must have (or be given) a pad token.
        split:
            One of {'train', 'val', 'test'}.
        training_cfg:
            Supplies max_caption_length; defaults to TrainingConfig().
        transform:
            Image preprocessing; defaults to the deterministic eval transform.
        random_caption:
            If True, sample randomly among up to three captions per image.

        Raises
        ------
        ValueError / FileNotFoundError / RuntimeError on bad split, missing
        files, or an empty split.
        """
        super().__init__()

        if split not in {"train", "val", "test"}:
            raise ValueError("split must be one of {'train', 'val', 'test'}")

        self.paths_cfg = paths_cfg
        self.tokenizer = tokenizer
        self.training_cfg = training_cfg or TrainingConfig()
        # If no transform is provided, fall back to a deterministic eval
        # transform so this class can still be used directly. In practice,
        # create_dataloader() will supply train/eval-specific transforms.
        self.transform = transform or eval_image_transform()
        self.random_caption = random_caption

        self.max_length: int = int(self.training_cfg.max_caption_length)

        # Load all captions from visual.token.txt
        token_path = self.paths_cfg.token_file
        if not os.path.exists(token_path):
            raise FileNotFoundError(f"Caption file not found: {token_path}")

        self.captions_by_image: Dict[str, List[str]] = {}
        with open(token_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    key, caption = line.split("\t", 1)
                except ValueError as exc:
                    raise ValueError(f"Malformed line in {token_path}: {line}") from exc

                # Keys look like "<image-name>#<caption-index>".
                img_name = key.split("#")[0]
                self.captions_by_image.setdefault(img_name, []).append(caption.strip())

        # Choose image list file based on split
        if split == "train":
            list_file = self.paths_cfg.train_list_file
        else:
            # We only have a single test list in this dataset; use it for both
            # 'val' and 'test' splits for now.
            list_file = self.paths_cfg.test_list_file

        if not os.path.exists(list_file):
            raise FileNotFoundError(f"Image list file for split '{split}' not found: {list_file}")

        self.image_ids: List[str] = []
        with open(list_file, "r", encoding="utf-8") as f:
            for line in f:
                img_name = line.strip()
                if not img_name:
                    continue
                if img_name not in self.captions_by_image:
                    # Skip images without captions to avoid runtime issues
                    continue
                self.image_ids.append(img_name)

        if not self.image_ids:
            raise RuntimeError(f"No images with captions found for split '{split}'.")

        print(f"Loaded {len(self.image_ids)} {split} images with captions.")

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
        img_name = self.image_ids[idx]
        img_path = os.path.join(self.paths_cfg.images_dir, img_name)

        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        image = Image.open(img_path).convert("RGB")
        image_tensor = self.transform(image)

        caption_list = self.captions_by_image[img_name]
        if not caption_list:
            raise RuntimeError(f"No captions available for image {img_name}")

        # Choose a caption. During training we consider up to three different
        # captions per image and randomly sample among them; for evaluation we
        # always take the first caption. We only strip leading/trailing
        # whitespace so that the raw textual content is preserved and no
        # characters are dropped before tokenization.
        if self.random_caption:
            limited_captions = caption_list[:3]
            caption = random.choice(limited_captions)
        else:
            caption = caption_list[0]
        caption = caption.strip()

        # Convert caption text into token IDs without adding any extra special
        # tokens so we retain a direct mapping between the raw caption string
        # and the token sequence.
        token_ids: List[int] = self.tokenizer.encode(
            caption,
            add_special_tokens=False,
        )

        # BUG FIX: the previous `bos_token_id or eos_token_id` silently
        # substituted EOS whenever BOS happened to be token id 0 (falsy).
        # Compare against None explicitly instead.
        bos_token_id = self.tokenizer.bos_token_id
        if bos_token_id is None:
            bos_token_id = self.tokenizer.eos_token_id
        eos_token_id = self.tokenizer.eos_token_id

        seq_ids: List[int] = [bos_token_id] + token_ids + [eos_token_id]

        # Truncate if necessary to respect max_length. If truncation occurs,
        # force the final position back to EOS so the decoder still learns
        # where captions terminate even for over-long captions.
        if len(seq_ids) > self.max_length:
            seq_ids = seq_ids[: self.max_length]
            seq_ids[-1] = eos_token_id

        # Pad up to max_length with pad_token_id and build attention mask.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # create_tokenizer() aliases pad to EOS, but guard against
            # tokenizers passed in without a pad token — torch.full would
            # reject a None fill value.
            pad_id = eos_token_id
        input_ids = torch.full(
            (self.max_length,),
            pad_id,
            dtype=torch.long,
        )
        attention_mask = torch.zeros(self.max_length, dtype=torch.long)

        seq_len = len(seq_ids)
        input_ids[:seq_len] = torch.tensor(seq_ids, dtype=torch.long)
        attention_mask[:seq_len] = 1

        # Labels are initially the same as input_ids; padding positions will
        # be set to -100 so they are ignored by the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "image": image_tensor,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "caption": caption,
            "image_id": img_name,
        }
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def create_tokenizer() -> GPT2TokenizerFast:
    """
    Build the GPT-2 fast tokenizer, guaranteeing a pad token exists.

    GPT-2 ships without a pad token, so it is aliased to EOS — padding
    positions are masked out of the loss separately via labels == -100.
    """

    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _infer_category_from_filename(filename: str) -> str:
|
| 234 |
+
"""
|
| 235 |
+
Infer a coarse category label from an image filename.
|
| 236 |
+
|
| 237 |
+
Heuristic:
|
| 238 |
+
- Strip directory and extension.
|
| 239 |
+
- Remove trailing digits to group files like 'bench1.jpg', 'bench25.jpg'
|
| 240 |
+
into the same category 'bench'.
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
base = os.path.basename(filename)
|
| 244 |
+
stem, _ext = os.path.splitext(base)
|
| 245 |
+
|
| 246 |
+
# Remove trailing digits
|
| 247 |
+
i = len(stem)
|
| 248 |
+
while i > 0 and stem[i - 1].isdigit():
|
| 249 |
+
i -= 1
|
| 250 |
+
category = stem[:i] or stem
|
| 251 |
+
|
| 252 |
+
return category
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _balanced_train_val_indices(
|
| 256 |
+
dataset: ImageCaptionDataset,
|
| 257 |
+
val_ratio: float = 0.2,
|
| 258 |
+
) -> Tuple[List[int], List[int]]:
|
| 259 |
+
"""
|
| 260 |
+
Split the dataset indices into train and validation sets.
|
| 261 |
+
|
| 262 |
+
The validation set:
|
| 263 |
+
- Targets approximately `val_ratio` of the total dataset size.
|
| 264 |
+
- Is balanced across categories inferred from filenames, i.e., each
|
| 265 |
+
category contributes (as much as possible) the same number of images.
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
num_items = len(dataset.image_ids)
|
| 269 |
+
if num_items == 0:
|
| 270 |
+
raise RuntimeError("Cannot create train/val split from an empty dataset.")
|
| 271 |
+
|
| 272 |
+
# Group indices by inferred category
|
| 273 |
+
category_to_indices: Dict[str, List[int]] = {}
|
| 274 |
+
for idx, img_name in enumerate(dataset.image_ids):
|
| 275 |
+
cat = _infer_category_from_filename(img_name)
|
| 276 |
+
category_to_indices.setdefault(cat, []).append(idx)
|
| 277 |
+
|
| 278 |
+
# Sort indices within each category for deterministic behavior
|
| 279 |
+
for indices in category_to_indices.values():
|
| 280 |
+
indices.sort()
|
| 281 |
+
|
| 282 |
+
categories = sorted(category_to_indices.keys())
|
| 283 |
+
num_categories = len(categories)
|
| 284 |
+
|
| 285 |
+
# Desired total size for validation set
|
| 286 |
+
target_val_size = max(1, int(round(val_ratio * num_items)))
|
| 287 |
+
|
| 288 |
+
# Base number of validation samples per category, constrained by the
|
| 289 |
+
# smallest category so we can keep counts balanced.
|
| 290 |
+
min_cat_size = min(len(category_to_indices[cat]) for cat in categories)
|
| 291 |
+
per_category = min(
|
| 292 |
+
min_cat_size,
|
| 293 |
+
max(1, int(round(target_val_size / max(1, num_categories)))),
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
val_indices: List[int] = []
|
| 297 |
+
train_indices: List[int] = []
|
| 298 |
+
|
| 299 |
+
for cat in categories:
|
| 300 |
+
indices = category_to_indices[cat]
|
| 301 |
+
val_for_cat = indices[:per_category]
|
| 302 |
+
train_for_cat = indices[per_category:]
|
| 303 |
+
val_indices.extend(val_for_cat)
|
| 304 |
+
train_indices.extend(train_for_cat)
|
| 305 |
+
|
| 306 |
+
return train_indices, val_indices
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def create_dataloader(
    paths_cfg: PathsConfig,
    training_cfg: TrainingConfig,
    split: str,
    tokenizer: Optional[GPT2TokenizerFast] = None,
    shuffle: Optional[bool] = None,
) -> Tuple[DataLoader, GPT2TokenizerFast]:
    """
    Factory function to create a DataLoader for a given split.

    'train' and 'val' are carved out of the training list via a
    deterministic, category-balanced 80/20 split; 'test' uses the
    dedicated test list file.

    Parameters
    ----------
    paths_cfg:
        Paths configuration.
    training_cfg:
        Training configuration containing batch size, max caption length, etc.
    split:
        One of {'train', 'val', 'test'}.
    tokenizer:
        Optional pre-initialized GPT-2 tokenizer. If None, a new one is created.
    shuffle:
        Optional flag to override shuffle behavior. If None, shuffle is True
        for the 'train' split and False otherwise.
    """

    if tokenizer is None:
        tokenizer = create_tokenizer()

    if shuffle is None:
        shuffle = split == "train"

    if split == "test":
        dataset = ImageCaptionDataset(
            paths_cfg=paths_cfg,
            tokenizer=tokenizer,
            split="test",
            training_cfg=training_cfg,
            transform=eval_image_transform(),
            random_caption=False,  # evaluation always uses the first caption
        )
    else:
        # Underlying full training dataset (augmented, random captions).
        full_train_dataset = ImageCaptionDataset(
            paths_cfg=paths_cfg,
            tokenizer=tokenizer,
            split="train",
            training_cfg=training_cfg,
            transform=train_image_transform(),
            random_caption=True,  # always randomize captions during training
        )

        train_indices, val_indices = _balanced_train_val_indices(
            full_train_dataset,
            val_ratio=0.2,
        )

        if split == "train":
            dataset = Subset(full_train_dataset, train_indices)
        elif split == "val":
            # FIX: the validation subset previously shared the training
            # dataset and therefore inherited random augmentation and random
            # caption sampling, making validation metrics noisy and
            # non-deterministic. Build a second view of the same image list
            # with deterministic preprocessing — the index split is valid for
            # both because they read the same list file in the same order.
            eval_view = ImageCaptionDataset(
                paths_cfg=paths_cfg,
                tokenizer=tokenizer,
                split="train",
                training_cfg=training_cfg,
                transform=eval_image_transform(),
                random_caption=False,
            )
            dataset = Subset(eval_view, val_indices)
        else:
            raise ValueError("split must be one of {'train', 'val', 'test'}")

    dataloader = DataLoader(
        dataset,
        batch_size=training_cfg.batch_size,
        shuffle=shuffle,
        num_workers=training_cfg.num_workers,
        pin_memory=torch.cuda.is_available(),  # speeds up host->GPU copies
    )

    return dataloader, tokenizer
|
| 385 |
+
|
image_captioning/evaluate.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
|
| 8 |
+
from nltk.translate.meteor_score import single_meteor_score
|
| 9 |
+
from rouge_score import rouge_scorer
|
| 10 |
+
|
| 11 |
+
from .config import PathsConfig, TrainingConfig, get_device, set_seed
|
| 12 |
+
from .dataset import create_dataloader, create_tokenizer
|
| 13 |
+
from .model import ImageCaptioningModel
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def parse_args() -> argparse.Namespace:
    """
    Build and parse the command-line interface for evaluation.
    """

    parser = argparse.ArgumentParser(description="Evaluate image captioning model on test set.")
    add = parser.add_argument
    add("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset.")
    add("--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt).")
    add("--batch_size", type=int, default=16, help="Batch size for evaluation.")
    add("--max_length", type=int, default=50, help="Maximum caption length during generation.")
    add("--num_beams", type=int, default=3, help="Number of beams for beam search.")
    add("--seed", type=int, default=42, help="Random seed.")
    add("--output_samples", type=str, default="evaluation_samples.jsonl", help="File to save sample predictions.")
    return parser.parse_args()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_metrics(
    references: List[List[str]],
    hypotheses: List[str],
) -> Dict[str, float]:
    """
    Compute BLEU (1-4), METEOR, and ROUGE-L metrics for a corpus.

    Parameters
    ----------
    references:
        One list of reference caption *strings* per image.
    hypotheses:
        One predicted caption string per image.

    Returns
    -------
    Dict[str, float]
        Metric name -> score.

    Raises
    ------
    ValueError
        If either input is empty or their lengths differ.
    """

    if not references or not hypotheses:
        raise ValueError("References and hypotheses must be non-empty.")
    if len(references) != len(hypotheses):
        raise ValueError("Number of references and hypotheses must match.")

    # BUG FIX: nltk's corpus_bleu expects *token* sequences. The previous
    # implementation passed raw strings, so BLEU was computed over
    # character n-grams rather than word n-grams. Tokenize by whitespace.
    tokenized_refs = [[ref.split() for ref in ref_list] for ref_list in references]
    tokenized_hyps = [hyp.split() for hyp in hypotheses]

    smoothie = SmoothingFunction().method4

    def _bleu(weights: Tuple[float, ...]) -> float:
        # One corpus-level BLEU at the given n-gram weighting.
        return corpus_bleu(
            tokenized_refs,
            tokenized_hyps,
            weights=weights,
            smoothing_function=smoothie,
        )

    bleu1 = _bleu((1.0, 0.0, 0.0, 0.0))
    bleu2 = _bleu((0.5, 0.5, 0.0, 0.0))
    bleu3 = _bleu((1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0))
    bleu4 = _bleu((0.25, 0.25, 0.25, 0.25))

    # METEOR — scored against the first reference only; if NLTK's WordNet
    # data is missing (single_meteor_score raises LookupError), fall back
    # to a simple unigram F1.
    meteor_scores: List[float] = []
    for ref_list, hyp in zip(references, hypotheses):
        ref_tokens = ref_list[0].split()
        hyp_tokens = hyp.split()
        try:
            meteor_scores.append(single_meteor_score(ref_tokens, hyp_tokens))
        except LookupError:
            ref_set = set(ref_tokens)
            hyp_set = set(hyp_tokens)
            if not ref_set or not hyp_set:
                meteor_scores.append(0.0)
            else:
                overlap = len(ref_set & hyp_set)
                precision = overlap / len(hyp_set)
                recall = overlap / len(ref_set)
                if precision + recall == 0:
                    meteor_scores.append(0.0)
                else:
                    meteor_scores.append(2 * precision * recall / (precision + recall))
    meteor = sum(meteor_scores) / max(1, len(meteor_scores))

    # ROUGE-L — also against the first reference.
    rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    rouge_l_scores: List[float] = [
        rouge.score(ref_list[0], hyp)["rougeL"].fmeasure
        for ref_list, hyp in zip(references, hypotheses)
    ]
    rouge_l = sum(rouge_l_scores) / max(1, len(rouge_l_scores))

    return {
        "BLEU-1": bleu1,
        "BLEU-2": bleu2,
        "BLEU-3": bleu3,
        "BLEU-4": bleu4,
        "METEOR": meteor,
        "ROUGE-L": rouge_l,
    }
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def run_evaluation(args: argparse.Namespace) -> None:
    """
    Evaluate a trained checkpoint on the test split.

    Loads the model, generates one caption per test image with beam search,
    computes BLEU/METEOR/ROUGE-L against the reference captions, prints the
    metrics, and writes up to 50 sample predictions as JSON lines.
    """

    paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = TrainingConfig(
        batch_size=args.batch_size,
        max_caption_length=args.max_length,
        num_epochs=1,  # not used during evaluation; kept minimal
    )

    set_seed(args.seed)
    device = get_device()

    tokenizer = create_tokenizer()
    test_loader, tokenizer = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="test",
        tokenizer=tokenizer,
        shuffle=False,
    )

    model = ImageCaptioningModel(training_cfg=training_cfg)
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    references: List[List[str]] = []
    hypotheses: List[str] = []

    num_samples_to_save = 50  # cap on persisted example predictions
    saved_samples: List[Dict[str, str]] = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            # The raw caption strings from the dataset serve as references.
            gold_captions = batch["caption"]

            # generate() is invoked one image at a time to respect its
            # constraints.
            for pos, gold in enumerate(gold_captions):
                prediction = model.generate(
                    images=images[pos : pos + 1],
                    max_length=args.max_length,
                    num_beams=args.num_beams,
                )[0]

                references.append([gold])
                hypotheses.append(prediction)

                if len(saved_samples) < num_samples_to_save:
                    saved_samples.append(
                        {
                            "image_id": batch["image_id"][pos],
                            "reference": gold,
                            "prediction": prediction,
                        }
                    )

    metrics = compute_metrics(references, hypotheses)

    print("Evaluation metrics:")
    for name, value in metrics.items():
        print(f"  {name}: {value:.4f}")

    # Persist sample predictions as JSON lines.
    output_path = args.output_samples
    with open(output_path, "w", encoding="utf-8") as f:
        for sample in saved_samples:
            f.write(json.dumps(sample) + "\n")

    print(f"Saved {len(saved_samples)} sample predictions to {output_path}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def main() -> None:
    """Entry point: validate the checkpoint path, then run evaluation."""
    cli_args = parse_args()

    checkpoint_path = cli_args.checkpoint
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    run_evaluation(cli_args)


if __name__ == "__main__":
    main()
|
| 207 |
+
|
image_captioning/inference.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
|
| 9 |
+
from .config import PathsConfig, TrainingConfig, get_device, set_seed
|
| 10 |
+
from .dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
|
| 11 |
+
from .model import ImageCaptioningModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI parser for single-image caption inference."""
    parser = argparse.ArgumentParser(
        description="Run image captioning inference on a single image."
    )
    parser.add_argument(
        "--image", type=str, required=True, help="Path to image file."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoints/best_model.pt",
        help="Path to model checkpoint.",
    )
    parser.add_argument(
        "--max_length", type=int, default=50, help="Maximum caption length."
    )
    parser.add_argument(
        "--num_beams", type=int, default=3, help="Number of beams for beam search."
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--data_root",
        type=str,
        default="/Users/ryan/Downloads/visuallyimpair",
        help="Root path to dataset (for consistency).",
    )
    return parser.parse_args()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_preprocess_transform() -> transforms.Compose:
    """Return the eval-time image pipeline: resize, center-crop, tensor, normalize.

    Must stay in sync with the preprocessing used during training.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_image(image_path: str) -> torch.Tensor:
    """Load one image from disk and preprocess it into a (1, 3, 224, 224) batch.

    Raises
    ------
    FileNotFoundError if `image_path` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    rgb_image = Image.open(image_path).convert("RGB")
    preprocess = build_preprocess_transform()
    # unsqueeze adds the batch dimension expected by the model
    return preprocess(rgb_image).unsqueeze(0)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_inference(args: argparse.Namespace) -> List[str]:
    """
    Generate and print a caption for the image given on the command line.

    Parameters
    ----------
    args:
        Parsed CLI namespace; reads `seed`, `data_root`, `max_length`,
        `checkpoint`, `image`, and `num_beams`.

    Returns
    -------
    List of generated caption strings (one entry for the single input image).

    Raises
    ------
    FileNotFoundError
        If the checkpoint file does not exist.
    """

    set_seed(args.seed)
    device = get_device()

    # Created for parity with the training script / future extensions;
    # intentionally unused here (leading underscore marks both as such).
    _paths_cfg = PathsConfig(data_root=args.data_root)
    _tokenizer = create_tokenizer()

    training_cfg = TrainingConfig(max_caption_length=args.max_length)
    model = ImageCaptioningModel(training_cfg=training_cfg)

    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")

    # NOTE(review): torch.load unpickles arbitrary objects; for untrusted
    # checkpoints prefer torch.load(..., weights_only=True) — confirm the
    # torch version in use supports it.
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    image_tensor = load_image(args.image).to(device)

    captions = model.generate(
        images=image_tensor,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )

    for caption in captions:
        print(f"Caption: {caption}")

    return captions
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def main() -> None:
    """CLI entry point: parse arguments and run single-image inference."""
    run_inference(parse_args())


if __name__ == "__main__":
    main()
|
| 101 |
+
|
image_captioning/model.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torchvision import models
|
| 9 |
+
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
|
| 10 |
+
|
| 11 |
+
from .config import TrainingConfig, get_device
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class ImageCaptioningOutput:
    """
    Lightweight result bundle returned by the captioning model's forward pass.

    Attributes
    ----------
    logits:
        Token logits of shape (batch_size, seq_len, vocab_size). `seq_len`
        counts only text positions; the visual prefix positions have been
        stripped before this object is constructed.
    loss:
        Cross-entropy loss over caption tokens, or None when no labels
        were supplied.
    """

    logits: Tensor
    loss: Optional[Tensor] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class EfficientNetB0Encoder(nn.Module):
    """
    EfficientNet-B0 image encoder using torchvision.

    The classification head is removed; `forward` returns only the pooled
    feature vector (dimension 1280 for B0).
    """

    def __init__(self, pretrained: bool = True) -> None:
        """
        Parameters
        ----------
        pretrained:
            If True, load ImageNet-pretrained weights.
        """
        super().__init__()
        # torchvision >= 0.13 uses the `weights=` enum API and the legacy
        # `pretrained=` flag was removed in 0.15. Prefer the new API and
        # fall back for older torchvision installs.
        try:
            weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
            effnet = models.efficientnet_b0(weights=weights)
        except AttributeError:  # torchvision < 0.13: enum does not exist
            effnet = models.efficientnet_b0(pretrained=pretrained)

        self.features = effnet.features
        self.avgpool = effnet.avgpool
        self.flatten = nn.Flatten()
        # in_features of the final classifier layer is the encoder output dim
        self.out_dim: int = effnet.classifier[1].in_features

    def forward(self, images: Tensor) -> Tensor:
        """
        Encode a batch of images into a pooled feature representation.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, 224, 224).

        Returns
        -------
        Tensor of shape (batch_size, out_dim).
        """
        x = self.features(images)
        x = self.avgpool(x)
        x = self.flatten(x)  # (batch_size, out_dim)
        return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ImageCaptioningModel(nn.Module):
    """
    Image captioning model with an EfficientNet-B0 vision encoder and GPT-2 decoder.

    The model projects visual features into a sequence of prefix embeddings that
    are concatenated with GPT-2 token embeddings. GPT-2 then predicts caption tokens.
    """

    def __init__(
        self,
        training_cfg: Optional[TrainingConfig] = None,
        pretrained_encoder: bool = True,
    ) -> None:
        super().__init__()

        # Falls back to a default TrainingConfig when none is given.
        self.training_cfg = training_cfg or TrainingConfig()
        self.device: torch.device = get_device()

        # Vision encoder
        self.encoder = EfficientNetB0Encoder(pretrained=pretrained_encoder)

        # Text decoder (GPT-2 small)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        if self.tokenizer.pad_token is None:
            # Use EOS as pad token (GPT-2 ships without a dedicated pad token)
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        self.gpt2.config.pad_token_id = self.tokenizer.pad_token_id

        # Number of visual prefix tokens prepended before the text tokens
        self.prefix_length: int = int(self.training_cfg.prefix_length)
        if self.prefix_length < 1:
            raise ValueError("prefix_length must be >= 1")

        # Project the pooled image feature vector to prefix_length token
        # embeddings in GPT-2's hidden space (single linear, then reshape).
        self.visual_projection = nn.Linear(
            self.encoder.out_dim,
            self.gpt2.config.n_embd * self.prefix_length,
        )

        # One-shot flag so shape diagnostics are printed only on the first forward
        self._printed_debug: bool = False

        # Move the whole module (encoder, GPT-2, projection) to the chosen device
        self.to(self.device)

    # --------------------------------------------------------------------- #
    # Internal utilities
    # --------------------------------------------------------------------- #
    def encode_images(self, images: Tensor) -> Tensor:
        """
        Encode images and produce visual prefix embeddings.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, H, W); the pipeline feeds 224x224.

        Returns
        -------
        Tensor of shape (batch_size, prefix_length, hidden_size).
        """

        assert images.dim() == 4, f"Expected images of shape (B,3,H,W), got {images.shape}"
        img_features = self.encoder(images)  # (B, encoder_out_dim)
        batch_size = img_features.size(0)

        # Single flat projection, reshaped into prefix_length embeddings
        prefix_embeddings = self.visual_projection(img_features)
        prefix_embeddings = prefix_embeddings.view(
            batch_size,
            self.prefix_length,
            self.gpt2.config.n_embd,
        )
        return prefix_embeddings

    # --------------------------------------------------------------------- #
    # Forward (training)
    # --------------------------------------------------------------------- #
    def forward(
        self,
        images: Tensor,
        captions: Tensor,
        attention_mask: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
    ) -> ImageCaptioningOutput:
        """
        Forward pass for training.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, 224, 224).
        captions:
            Token IDs of shape (batch_size, seq_len).
        attention_mask:
            Optional attention mask of shape (batch_size, seq_len). The visual
            prefix positions are always attended to (mask of ones is prepended).
        labels:
            Optional target token IDs of shape (batch_size, seq_len).
            If provided, cross-entropy loss is computed, ignoring positions
            with label -100.

        Returns
        -------
        ImageCaptioningOutput with logits over text positions only, plus
        the optional loss.
        """

        # Inputs may arrive on CPU from the DataLoader; move everything to
        # the model's device first.
        images = images.to(self.device)
        captions = captions.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)

        batch_size, seq_len = captions.shape
        assert images.size(0) == batch_size, "Batch size mismatch between images and captions."

        prefix_embeddings = self.encode_images(images)  # (B, P, H)

        # Embed caption tokens and prepend the visual prefix along the
        # sequence dimension.
        token_embeddings = self.gpt2.transformer.wte(captions)  # (B, T, H)
        inputs_embeds = torch.cat([prefix_embeddings, token_embeddings], dim=1)  # (B, P+T, H)

        if attention_mask is not None:
            # Prefix positions are always valid: mask of ones, prepended.
            prefix_mask = torch.ones(
                batch_size,
                self.prefix_length,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            extended_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        else:
            extended_attention_mask = None

        # Print shape diagnostics exactly once (first forward call).
        if not self._printed_debug:
            print(f"[DEBUG] images shape: {images.shape}")
            print(f"[DEBUG] captions shape: {captions.shape}")
            print(f"[DEBUG] prefix_embeddings: {prefix_embeddings.shape}")
            print(f"[DEBUG] token_embeddings: {token_embeddings.shape}")
            print(f"[DEBUG] inputs_embeds shape: {inputs_embeds.shape}")
            if extended_attention_mask is not None:
                print(f"[DEBUG] attention_mask shape: {extended_attention_mask.shape}")
            self._printed_debug = True

        outputs = self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            use_cache=False,
            return_dict=True,
        )

        # Remove visual prefix positions from the logits so that
        # the returned logits only correspond to text tokens.
        logits = outputs.logits[:, self.prefix_length :, :]  # (B, T, V)

        loss: Optional[Tensor] = None
        if labels is not None:
            if labels.shape != (batch_size, seq_len):
                raise ValueError(
                    f"labels shape {labels.shape} does not match captions shape {(batch_size, seq_len)}"
                )

            # Shift logits and labels for next-token prediction:
            # position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # -100 is the conventional "ignore" label (padding etc.)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return ImageCaptioningOutput(logits=logits, loss=loss)

    # --------------------------------------------------------------------- #
    # Generation (inference)
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def generate(
        self,
        images: Tensor,
        max_length: int = 50,
        num_beams: int = 1,
        temperature: float = 1.0,
        top_k: int = 0,
        eos_token_id: Optional[int] = None,
        length_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """
        Generate captions for a batch of images using a simple beam search.

        Notes
        -----
        - For simplicity and clarity, this implementation currently supports
          batch_size == 1. A ValueError is raised otherwise.
        - Because use_cache=False, the full prefix+sequence is re-encoded at
          every decoding step, so cost grows quadratically with caption length.
        - NOTE(review): `repetition_penalty` is accepted but never applied in
          this body — either wire it in or document it as a no-op.
        - NOTE(review): `eos_token_id = eos_token_id or ...` would also replace
          an explicit eos_token_id of 0; harmless for GPT-2 (EOS is 50256) but
          worth confirming if other tokenizers are used.
        """

        self.eval()

        images = images.to(self.device)
        batch_size = images.size(0)
        if batch_size != 1:
            raise ValueError(f"generate currently supports batch_size == 1, got {batch_size}")

        eos_token_id = eos_token_id or self.tokenizer.eos_token_id
        bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id

        prefix_embeddings = self.encode_images(images)  # (1, P, H)

        # Each beam is (token_ids, log_prob); start from a single empty beam.
        beams: List[Tuple[List[int], float]] = [([], 0.0)]

        def _length_normalized_score(tokens: List[int], score: float) -> float:
            # length_penalty == 0 -> raw cumulative log-prob (no normalization)
            if length_penalty is None or length_penalty == 0.0:
                return score
            length = max(1, len(tokens))
            return score / (length ** length_penalty)

        for _ in range(max_length):
            all_candidates: List[Tuple[List[int], float]] = []
            for seq, score in beams:
                if seq and seq[-1] == eos_token_id:
                    # If already finished, keep as-is
                    all_candidates.append((seq, score))
                    continue

                # Build a 2D tensor of token IDs with shape (1, L);
                # an empty beam is seeded with the BOS token.
                if seq:
                    input_ids = torch.tensor(
                        [seq],
                        device=self.device,
                        dtype=torch.long,
                    )  # (1, L)
                else:
                    input_ids = torch.tensor(
                        [[bos_token_id]],
                        device=self.device,
                        dtype=torch.long,
                    )  # (1, 1)

                # Re-embed and re-run the full sequence each step
                # (no KV cache; see docstring note).
                token_embeddings = self.gpt2.transformer.wte(input_ids)  # (1, L, H)
                inputs_embeds = torch.cat([prefix_embeddings, token_embeddings], dim=1)

                attention_mask = torch.ones(
                    inputs_embeds.size()[:-1],
                    dtype=torch.long,
                    device=self.device,
                )

                outputs = self.gpt2(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=True,
                )

                # Last-position logits; temperature is floored to avoid
                # division by zero.
                logits = outputs.logits[:, -1, :] / max(temperature, 1e-5)

                if top_k > 0:
                    # Expand each beam by its top_k next tokens
                    topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
                    log_probs = torch.log_softmax(topk_logits, dim=-1)
                    for i in range(top_k):
                        token_id = int(topk_indices[0, i])
                        candidate = (seq + [token_id], score + float(log_probs[0, i]))
                        all_candidates.append(candidate)
                else:
                    # Expand each beam by its num_beams best next tokens
                    log_probs = torch.log_softmax(logits, dim=-1)
                    topk_log_probs, topk_indices = torch.topk(log_probs, num_beams, dim=-1)
                    for i in range(num_beams):
                        token_id = int(topk_indices[0, i])
                        candidate = (seq + [token_id], score + float(topk_log_probs[0, i]))
                        all_candidates.append(candidate)

            # Select best beams. With num_beams=1 and length_penalty=0 this
            # reduces to simple greedy decoding, which is fully deterministic.
            beams = sorted(
                all_candidates,
                key=lambda x: _length_normalized_score(x[0], x[1]),
                reverse=True,
            )[:num_beams]

            # If all beams ended with EOS, stop early
            if all(seq and seq[-1] == eos_token_id for seq, _ in beams):
                break

        # NOTE(review): best_score is computed but unused below.
        best_seq, best_score = max(
            beams,
            key=lambda x: _length_normalized_score(x[0], x[1]),
        )

        # Truncate at EOS if present
        if eos_token_id in best_seq:
            best_seq = best_seq[: best_seq.index(eos_token_id)]

        caption = self.tokenizer.decode(best_seq, skip_special_tokens=True)
        # Normalize whitespace so the final caption is a single, clean string.
        caption = " ".join(caption.strip().split())
        return [caption]

    # --------------------------------------------------------------------- #
    # Dummy test helper
    # --------------------------------------------------------------------- #
    def test_dummy(self) -> None:
        """
        Run a dummy forward pass to verify the model works end-to-end.

        This matches the specification in the prompt and asserts that the
        output logits have shape (2, 20, 50257) when captions have length 20.
        """

        self.eval()
        vocab_size = int(self.gpt2.config.vocab_size)

        dummy_images = torch.randn(2, 3, 224, 224, device=self.device)
        dummy_captions = torch.randint(0, vocab_size, (2, 20), device=self.device)

        # ExitStack lets us conditionally enter autocast only on CUDA.
        with torch.no_grad(), contextlib.ExitStack() as stack:
            if self.device.type == "cuda":
                stack.enter_context(torch.cuda.amp.autocast())

            outputs = self(dummy_images, dummy_captions)

        logits = outputs.logits
        assert logits.shape == (2, 20, vocab_size), (
            f"Output shape mismatch: expected (2, 20, {vocab_size}), "
            f"got {tuple(logits.shape)}"
        )
        print("✓ Model architecture verified successfully!")
|
| 382 |
+
|
image_captioning/train.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.optim import AdamW
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 12 |
+
|
| 13 |
+
from .config import PathsConfig, TrainingConfig, ensure_dir, get_device, set_seed
|
| 14 |
+
from .dataset import create_dataloader, create_tokenizer
|
| 15 |
+
from .model import ImageCaptioningModel
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI parser for the training script."""
    parser = argparse.ArgumentParser(
        description="Train EfficientNetB0 + GPT-2 image captioning model."
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default="/Users/ryan/Downloads/visuallyimpair",
        help="Root path to dataset.",
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs."
    )
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate.")
    parser.add_argument(
        "--warmup_steps", type=int, default=500, help="Number of warmup steps."
    )
    parser.add_argument(
        "--max_length", type=int, default=50, help="Maximum caption length."
    )
    parser.add_argument(
        "--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="Directory to save checkpoints.",
    )
    parser.add_argument(
        "--log_dir", type=str, default="runs", help="Directory for TensorBoard logs."
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=10,
        help="Early stopping patience based on validation loss.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    return parser.parse_args()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_training_config_from_args(args: argparse.Namespace) -> TrainingConfig:
    """Translate parsed CLI arguments into a populated TrainingConfig."""
    cfg = TrainingConfig()
    overrides = {
        "learning_rate": args.lr,
        "batch_size": args.batch_size,
        "num_epochs": args.epochs,
        "warmup_steps": args.warmup_steps,
        "max_caption_length": args.max_length,
        # Accumulation of fewer than 1 step makes no sense; clamp to 1.
        "gradient_accumulation_steps": max(1, args.grad_accum_steps),
        "output_dir": args.output_dir,
        "log_dir": args.log_dir,
        "patience": args.patience,
        "seed": args.seed,
    }
    for attr_name, value in overrides.items():
        setattr(cfg, attr_name, value)
    return cfg
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def validate_dataloader(
    train_loader,
    device: torch.device,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Pull one batch from the DataLoader to sanity-check dataset loading.

    Prints the shape of each tensor in the batch and returns them.

    Returns
    -------
    Tuple of (images, input_ids, attention_mask, labels).

    Raises
    ------
    RuntimeError if the loader yields no batches.
    """
    loader_iter = iter(train_loader)
    try:
        first_batch = next(loader_iter)
    except StopIteration as exc:
        raise RuntimeError("Training DataLoader is empty. Check your dataset configuration.") from exc

    images = first_batch["image"].to(device)
    input_ids = first_batch["input_ids"].to(device)
    attention_mask = first_batch["attention_mask"].to(device)
    labels = first_batch["labels"].to(device)

    for label, tensor in (
        ("images", images),
        ("input_ids", input_ids),
        ("attention_mask", attention_mask),
        ("labels", labels),
    ):
        print(f"[DATA] {label} batch shape: {tensor.shape}")

    return images, input_ids, attention_mask, labels
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def train_one_epoch(
    model: ImageCaptioningModel,
    train_loader,
    optimizer: AdamW,
    scheduler,
    device: torch.device,
    cfg: TrainingConfig,
    epoch: int,
    scaler: torch.cuda.amp.GradScaler,
    writer: SummaryWriter,
) -> float:
    """
    Train the model for a single epoch.

    Runs mixed-precision forward/backward with gradient accumulation:
    the loss is scaled by 1/grad_accum_steps, gradients accumulate across
    micro-batches, and the optimizer/scheduler step once every
    grad_accum_steps batches (with gradient clipping at cfg.max_grad_norm).
    Logs the epoch-average loss to TensorBoard under "Loss/train".

    Returns
    -------
    Average (unscaled) loss over all batches of the epoch.

    Raises
    ------
    RuntimeError if the model forward does not produce a loss.

    NOTE(review): when len(train_loader) is not a multiple of
    grad_accum_steps, the final partial accumulation is never stepped and
    its gradients are discarded by the next epoch's zero_grad — confirm
    whether a flush after the loop is desired.
    """

    model.train()
    running_loss = 0.0
    num_steps = 0

    grad_accum_steps = cfg.gradient_accumulation_steps

    progress = tqdm(train_loader, desc=f"Epoch {epoch} [train]", unit="batch")
    for step, batch in enumerate(progress):
        # DataLoader tensors arrive on CPU; move each to the training device.
        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Autocast is active only on CUDA with mixed precision enabled.
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda" and cfg.mixed_precision)):
            outputs = model(
                images=images,
                captions=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            if loss is None:
                raise RuntimeError("Model did not return a loss during training.")

            # Scale so that accumulated gradients match a full-batch step.
            loss = loss / grad_accum_steps

        scaler.scale(loss).backward()

        # Step only on accumulation boundaries.
        if (step + 1) % grad_accum_steps == 0:
            # Unscale before clipping so max_grad_norm applies to true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

        # Undo the accumulation scaling for reporting purposes.
        running_loss += loss.item() * grad_accum_steps
        num_steps += 1
        avg_loss = running_loss / num_steps
        progress.set_postfix({"loss": f"{avg_loss:.4f}"})

    epoch_loss = running_loss / max(1, num_steps)
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    return epoch_loss
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def evaluate(
    model: ImageCaptioningModel,
    val_loader,
    device: torch.device,
    cfg: TrainingConfig,
    epoch: int,
    writer: SummaryWriter,
) -> float:
    """Run one full pass over the validation loader and return the mean loss.

    Logs the epoch-level loss to TensorBoard under "Loss/val". The `cfg`
    parameter is accepted for signature parity with train_one_epoch and is
    not read here.

    Raises
    ------
    RuntimeError if the model forward does not produce a loss.
    """
    model.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Epoch {epoch} [val]", unit="batch")
        for batch in progress:
            batch_images = batch["image"].to(device)
            batch_input_ids = batch["input_ids"].to(device)
            batch_attention = batch["attention_mask"].to(device)
            batch_labels = batch["labels"].to(device)

            model_out = model(
                images=batch_images,
                captions=batch_input_ids,
                attention_mask=batch_attention,
                labels=batch_labels,
            )
            if model_out.loss is None:
                raise RuntimeError("Model did not return a loss during validation.")

            total_loss += model_out.loss.item()
            batch_count += 1
            progress.set_postfix({"val_loss": f"{total_loss / batch_count:.4f}"})

    val_loss = total_loss / max(1, batch_count)
    writer.add_scalar("Loss/val", val_loss, epoch)
    return val_loss
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def main() -> None:
    """
    End-to-end training entry point.

    Builds the data pipeline, model, optimizer, and cosine LR schedule from
    CLI arguments, then trains with early stopping on validation loss. The
    best-by-validation-loss weights are saved to ``<output_dir>/best_model.pt``.
    """
    args = parse_args()

    # Configuration and setup
    paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = create_training_config_from_args(args)

    ensure_dir(training_cfg.output_dir)
    ensure_dir(training_cfg.log_dir)

    set_seed(training_cfg.seed)
    device = get_device()

    # Data. create_dataloader returns a (possibly augmented) tokenizer for
    # the train split; reuse that same tokenizer for the val split so both
    # splits share one vocabulary.
    tokenizer = create_tokenizer()
    train_loader, tokenizer = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="train",
        tokenizer=tokenizer,
        shuffle=True,
    )
    val_loader, _ = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="val",
        tokenizer=tokenizer,
        shuffle=False,
    )

    # Validate dataset loading
    validate_dataloader(train_loader, device)

    # Model
    model = ImageCaptioningModel(training_cfg=training_cfg)
    # FIX: the model was never explicitly moved to the compute device even
    # though every batch is moved in the train/eval loops, which fails on
    # CUDA with a device-mismatch error (unless the constructor moves itself
    # internally — in that case .to(device) is a harmless no-op).
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=training_cfg.learning_rate)

    # One optimizer step is taken per `gradient_accumulation_steps` batches,
    # so the scheduler horizon is optimizer steps, not batches.
    total_training_steps = math.ceil(
        len(train_loader) / max(1, training_cfg.gradient_accumulation_steps)
    ) * training_cfg.num_epochs

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=training_cfg.warmup_steps,
        num_training_steps=total_training_steps,
    )

    # GradScaler is a no-op when disabled (CPU run or mixed precision off).
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda" and training_cfg.mixed_precision))
    writer = SummaryWriter(log_dir=training_cfg.log_dir)

    best_val_loss = float("inf")
    epochs_without_improvement = 0

    try:
        for epoch in range(1, training_cfg.num_epochs + 1):
            train_loss = train_one_epoch(
                model=model,
                train_loader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                cfg=training_cfg,
                epoch=epoch,
                scaler=scaler,
                writer=writer,
            )

            val_loss = evaluate(
                model=model,
                val_loader=val_loader,
                device=device,
                cfg=training_cfg,
                epoch=epoch,
                writer=writer,
            )

            print(f"[EPOCH {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f}")

            # Checkpointing: keep only the best-by-validation-loss weights.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                best_path = os.path.join(training_cfg.output_dir, "best_model.pt")
                torch.save(model.state_dict(), best_path)
                print(f"[CHECKPOINT] Saved new best model to {best_path}")
            else:
                epochs_without_improvement += 1
                print(
                    f"[EARLY STOP] No improvement for {epochs_without_improvement} "
                    f"epoch(s) (patience={training_cfg.patience})."
                )

                if epochs_without_improvement >= training_cfg.patience:
                    print("Early stopping triggered.")
                    break
    except Exception as exc:  # noqa: BLE001
        # Surface the failure but still re-raise so the exit code is nonzero.
        print(f"[ERROR] Training failed with error: {exc}")
        raise
    finally:
        # Flush TensorBoard logs even on failure or early exit.
        writer.close()
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
main()
|
| 297 |
+
|