import argparse
import os
from typing import List

import torch
from PIL import Image
from torchvision import transforms

from .config import PathsConfig, TrainingConfig, get_device, set_seed
from .dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from .model import ImageCaptioningModel


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for inference."""
    parser = argparse.ArgumentParser(description="Run image captioning inference on a single image.")
    parser.add_argument("--image", type=str, required=True, help="Path to image file.")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt", help="Path to model checkpoint.")
    parser.add_argument("--max_length", type=int, default=50, help="Maximum caption length.")
    parser.add_argument("--num_beams", type=int, default=3, help="Number of beams for beam search.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset (for consistency).")
    return parser.parse_args()


def build_preprocess_transform() -> transforms.Compose:
    """Build the image preprocessing transform matching the training pipeline."""
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )


def load_image(image_path: str) -> torch.Tensor:
    """Load and preprocess a single image."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    image = Image.open(image_path).convert("RGB")
    transform = build_preprocess_transform()
    tensor = transform(image).unsqueeze(0)  # (1, 3, 224, 224)
    return tensor


def run_inference(args: argparse.Namespace) -> List[str]:
    """Run caption generation on the specified image and print the result."""
    set_seed(args.seed)
    device = get_device()

    _paths_cfg = PathsConfig(data_root=args.data_root)  # Included for consistency and future extensions
    training_cfg = TrainingConfig(max_caption_length=args.max_length)

    tokenizer = create_tokenizer()
    model = ImageCaptioningModel(training_cfg=training_cfg)

    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    image_tensor = load_image(args.image).to(device)
    captions = model.generate(
        images=image_tensor,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )

    for caption in captions:
        print(f"Caption: {caption}")
    return captions


def main() -> None:
    args = parse_args()
    run_inference(args)


if __name__ == "__main__":
    main()
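
# Example invocation, based on the flags defined in parse_args() above.
# This is a sketch: the module path is an assumption (it depends on the
# package name this file lives under); replace <package_name> accordingly.
#
#   python -m <package_name>.inference \
#       --image path/to/photo.jpg \
#       --checkpoint checkpoints/best_model.pt \
#       --max_length 50 \
#       --num_beams 3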