# (removed web-page scrape residue: "Spaces: / Running / Running" — Hugging Face Spaces status text, not code)
#!/usr/bin/env python
"""Generate an image caption for a single image using a trained checkpoint.

Inference companion to resnet_train.py: loads the encoder/decoder weights and
the vocabulary saved during training, then prints the generated caption.
"""
import argparse

import torch
from PIL import Image
import nltk

# Import-time side effect: ensure the NLTK 'punkt' tokenizer data is present
# (quiet download is a no-op if it is already installed).
nltk.download('punkt', quiet=True)

# Components from the training script. resnet_train is ALSO imported as a
# module so its global `vocab` can be replaced with the checkpoint's copy
# before captioning (see main()).
from resnet_train import EncoderCNN, DecoderRNN, visualize_attention, CONFIG, Vocabulary
import resnet_train
def main():
    """Load a trained encoder/decoder checkpoint and print a caption for one image.

    Command-line interface (unchanged):
        --image       path to the input image (required)
        --checkpoint  path to the trained model checkpoint (required)
    """
    parser = argparse.ArgumentParser(description="Generate image caption from a trained model.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained model checkpoint")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects, so only load checkpoints from a trusted source. It is needed
    # here because the checkpoint embeds a Vocabulary instance, not just
    # tensor state dicts.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)

    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Fix: switch both models to inference mode so dropout / batch-norm layers
    # behave deterministically — the original left them in training mode.
    encoder.eval()
    decoder.eval()

    # visualize_attention reads the module-level `vocab` inside resnet_train,
    # so replace it with the vocabulary saved alongside the weights.
    resnet_train.vocab = checkpoint['vocab']

    caption = visualize_attention(args.image, encoder, decoder, device)
    print(caption)


if __name__ == "__main__":
    main()