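"""Command-line inference for the image captioning model.

Loads a trained checkpoint, applies the same preprocessing used during
training, and prints a beam-search caption for a single image.

Example invocation (a sketch; run from the project root and substitute the
actual package path for <package>, since this file uses relative imports):

    python -m <package>.infer --image path/to/photo.jpg --num_beams 5
"""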
import argparse
import os
from typing import List

import torch
from PIL import Image
from torchvision import transforms

from .config import PathsConfig, TrainingConfig, get_device, set_seed
from .dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from .model import ImageCaptioningModel


def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for inference.
    """

    parser = argparse.ArgumentParser(description="Run image captioning inference on a single image.")
    parser.add_argument("--image", type=str, required=True, help="Path to image file.")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt", help="Path to model checkpoint.")
    parser.add_argument("--max_length", type=int, default=50, help="Maximum caption length.")
    parser.add_argument("--num_beams", type=int, default=3, help="Number of beams for beam search.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset (for consistency).")
    return parser.parse_args()


def build_preprocess_transform() -> transforms.Compose:
    """
    Build image preprocessing transform matching the training pipeline.
    """

    return transforms.Compose(
        [
            transforms.Resize(256),  # Shorter edge to 256 px, preserving aspect ratio
            transforms.CenterCrop(224),  # 224x224 crop: the input size expected by ImageNet-pretrained encoders
            transforms.ToTensor(),  # PIL image -> float32 tensor in [0, 1], CHW layout
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # ImageNet channel statistics
        ]
    )


def load_image(image_path: str) -> torch.Tensor:
    """
    Load and preprocess a single image.
    """

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    image = Image.open(image_path).convert("RGB")
    transform = build_preprocess_transform()
    tensor = transform(image).unsqueeze(0)  # (1, 3, 224, 224)
    return tensor


def run_inference(args: argparse.Namespace) -> List[str]:
    """
    Run caption generation on the specified image and print the result.
    """

    set_seed(args.seed)
    device = get_device()

    _paths_cfg = PathsConfig(data_root=args.data_root)  # Included for consistency and future extensions
    training_cfg = TrainingConfig(max_caption_length=args.max_length)

    tokenizer = create_tokenizer()

    model = ImageCaptioningModel(training_cfg=training_cfg)
    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")

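    # Assumption: the training script saves model.state_dict() directly; if it
    # wraps the weights (e.g. {"model_state_dict": ...}), unwrap before loading.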
    # The checkpoint is expected to hold only tensors, so weights_only=True
    # (available in PyTorch >= 1.13) is the safer way to unpickle it.
    state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    image_tensor = load_image(args.image).to(device)

    # Disable autograd during generation; gradients are never needed at inference time.
    with torch.no_grad():
        captions = model.generate(
            images=image_tensor,
            max_length=args.max_length,
            num_beams=args.num_beams,
        )

    for caption in captions:
        print(f"Caption: {caption}")

    return captions


def main() -> None:
    args = parse_args()
    run_inference(args)


if __name__ == "__main__":
    main()