#!/usr/bin/env python
# ImageCaptionner/scripts/resnet_caption.py
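"""Generate a caption for a single image using a trained encoder/decoder checkpoint.

Example usage (the image and checkpoint paths are placeholders, substitute your own):
    python resnet_caption.py --image samples/dog.jpg --checkpoint checkpoints/best_model.pth
"""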
import argparse
import torch
from PIL import Image
import nltk
nltk.download('punkt', quiet=True)
# Import the necessary components from resnet_train.py
from resnet_train import EncoderCNN, DecoderRNN, visualize_attention, CONFIG, Vocabulary
import resnet_train # To update its global vocab variable
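# Expected resnet_train interface (inferred from how it is used below):
#   - EncoderCNN() and DecoderRNN() build the models with their training-time defaults
#   - visualize_attention(image_path, encoder, decoder, device) returns the generated caption
#   - resnet_train.vocab is the module-level Vocabulary used during decoding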
def main():
    parser = argparse.ArgumentParser(description="Generate an image caption from a trained model.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained model checkpoint")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint (weights_only=False because it also stores the Vocabulary object)
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)

    # Initialize models
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)

    # Load state dictionaries and switch to evaluation mode for inference
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    encoder.eval()
    decoder.eval()

    # Update the global vocabulary from the checkpoint
    resnet_train.vocab = checkpoint['vocab']

    # Generate caption using the provided image path
    caption = visualize_attention(args.image, encoder, decoder, device)
    print(caption)
if __name__ == "__main__":
    main()