# NOTE: removed page-scrape residue ("Spaces: Sleeping" Hugging Face status banner) — not part of the source.
import argparse
import os
from typing import List

import torch
from PIL import Image
from torchvision import transforms

from .config import PathsConfig, TrainingConfig, get_device, set_seed
from .dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from .model import ImageCaptioningModel
def parse_args() -> argparse.Namespace:
    """Collect and parse the command-line options for single-image inference.

    Returns:
        argparse.Namespace with ``image``, ``checkpoint``, ``max_length``,
        ``num_beams``, ``seed``, and ``data_root`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run image captioning inference on a single image."
    )
    # Flag specs kept in declaration order so -h output is unchanged.
    arg_specs = [
        ("--image", dict(type=str, required=True, help="Path to image file.")),
        ("--checkpoint", dict(type=str, default="checkpoints/best_model.pt", help="Path to model checkpoint.")),
        ("--max_length", dict(type=int, default=50, help="Maximum caption length.")),
        ("--num_beams", dict(type=int, default=3, help="Number of beams for beam search.")),
        ("--seed", dict(type=int, default=42, help="Random seed.")),
        ("--data_root", dict(type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset (for consistency).")),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def build_preprocess_transform() -> transforms.Compose:
    """Return the eval-time image transform.

    Must mirror the training pipeline: resize the shorter side to 256,
    center-crop to 224x224, convert to a tensor, then normalize with the
    ImageNet statistics used at training time.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
def load_image(image_path: str) -> torch.Tensor:
    """Read an image from disk and preprocess it into a batched model input.

    Args:
        image_path: Filesystem path to the image file.

    Returns:
        Float tensor of shape (1, 3, 224, 224) — batch dimension prepended.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    transform = build_preprocess_transform()
    # Fix: close the underlying file handle — the original left the
    # PIL-opened file dangling. convert("RGB") materializes the pixel
    # data, so closing afterwards is safe.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    return transform(image).unsqueeze(0)  # (1, 3, 224, 224)
def run_inference(args: argparse.Namespace) -> List[str]:
    """Generate caption(s) for the image given on the command line, print each.

    Args:
        args: Parsed CLI options (``image``, ``checkpoint``, ``max_length``,
            ``num_beams``, ``seed``, ``data_root``).

    Returns:
        The list of generated caption strings.

    Raises:
        FileNotFoundError: If the checkpoint or image file is missing.
    """
    set_seed(args.seed)
    device = get_device()

    _paths_cfg = PathsConfig(data_root=args.data_root)  # Included for consistency and future extensions
    training_cfg = TrainingConfig(max_caption_length=args.max_length)
    # NOTE(review): tokenizer is created but never used here — presumably
    # create_tokenizer() has needed side effects (e.g. vocab caching); confirm.
    tokenizer = create_tokenizer()

    model = ImageCaptioningModel(training_cfg=training_cfg)
    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
    # NOTE(review): torch.load unpickles arbitrary objects — consider
    # weights_only=True (torch>=1.13) if checkpoints may be untrusted.
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    image_tensor = load_image(args.image).to(device)
    # Fix: disable autograd during generation — the original tracked
    # gradients needlessly, wasting memory at inference time.
    with torch.no_grad():
        captions = model.generate(
            images=image_tensor,
            max_length=args.max_length,
            num_beams=args.num_beams,
        )
    for caption in captions:
        print(f"Caption: {caption}")
    return captions
def main() -> None:
    """CLI entry point: parse arguments, then run captioning inference."""
    run_inference(parse_args())


if __name__ == "__main__":
    main()