# NOTE: Hugging Face file-listing residue removed ("Ryanfafa's picture / Upload 7 files / 19ea5c5 verified") — it was not Python code.
import argparse
import os
from typing import List
import torch
from PIL import Image
from torchvision import transforms
from .config import PathsConfig, TrainingConfig, get_device, set_seed
from .dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from .model import ImageCaptioningModel
def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for single-image captioning inference.

    Returns:
        argparse.Namespace with ``image``, ``checkpoint``, ``max_length``,
        ``num_beams``, ``seed``, and ``data_root`` attributes.
    """
    parser = argparse.ArgumentParser(description="Run image captioning inference on a single image.")
    # Flag specs kept in one table so defaults/help stay easy to scan.
    cli_options = [
        ("--image", dict(type=str, required=True, help="Path to image file.")),
        ("--checkpoint", dict(type=str, default="checkpoints/best_model.pt", help="Path to model checkpoint.")),
        ("--max_length", dict(type=int, default=50, help="Maximum caption length.")),
        ("--num_beams", dict(type=int, default=3, help="Number of beams for beam search.")),
        ("--seed", dict(type=int, default=42, help="Random seed.")),
        ("--data_root", dict(type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset (for consistency).")),
    ]
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def build_preprocess_transform() -> transforms.Compose:
    """
    Create the inference-time image preprocessing pipeline.

    Mirrors the training pipeline: resize the short side to 256, take a
    224x224 center crop, convert to a tensor, and normalize with the
    ImageNet mean/std constants shared with the dataset module.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
def load_image(image_path: str) -> torch.Tensor:
    """
    Load an image from disk and preprocess it for the model.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        Float tensor of shape (1, 3, 224, 224) — a single preprocessed
        image with a leading batch dimension.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    transform = build_preprocess_transform()
    # Use the context-manager form so the underlying file handle is closed
    # promptly instead of waiting for garbage collection; convert() inside
    # the block yields an independent, fully-loaded RGB copy.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")
    return transform(rgb_image).unsqueeze(0)  # (1, 3, 224, 224)
def run_inference(args: argparse.Namespace) -> List[str]:
    """
    Run caption generation on the specified image and print the result.

    Args:
        args: Parsed CLI arguments (``image``, ``checkpoint``, ``max_length``,
            ``num_beams``, ``seed``, ``data_root``).

    Returns:
        The list of generated caption strings.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist (the image
            path is checked inside ``load_image``).
    """
    set_seed(args.seed)
    device = get_device()
    _paths_cfg = PathsConfig(data_root=args.data_root)  # Included for consistency and future extensions
    training_cfg = TrainingConfig(max_caption_length=args.max_length)
    # NOTE(review): this tokenizer is created but never passed to the model
    # or to generate(); presumably ImageCaptioningModel builds its own —
    # confirm, or wire this instance in.
    tokenizer = create_tokenizer()
    model = ImageCaptioningModel(training_cfg=training_cfg)
    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
    # weights_only=True restricts unpickling to tensor/primitive data, which
    # is all a state_dict should contain; avoids arbitrary-code execution
    # when loading a checkpoint from an untrusted source.
    state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    image_tensor = load_image(args.image).to(device)
    # Disable autograd bookkeeping during generation — inference only.
    with torch.no_grad():
        captions = model.generate(
            images=image_tensor,
            max_length=args.max_length,
            num_beams=args.num_beams,
        )
    for caption in captions:
        print(f"Caption: {caption}")
    return captions
def main() -> None:
    """CLI entry point: parse arguments and run captioning inference."""
    run_inference(parse_args())
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()