File size: 1,369 Bytes
32d4a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python
import argparse
import torch
from PIL import Image
import nltk
# Fetch the NLTK 'punkt' data at import time (quiet=True suppresses the
# download log). NOTE(review): presumably needed by the tokenization code in
# resnet_train — confirm; nothing in this file uses nltk directly.
nltk.download('punkt', quiet=True)

# Import the necessary components from resnet_train.py
from resnet_train import EncoderCNN, DecoderRNN, visualize_attention, CONFIG, Vocabulary
import resnet_train  # To update its global vocab variable

def main():
    """Print a caption for one image using a trained encoder/decoder checkpoint.

    Command-line arguments:
        --image: path to the input image (required).
        --checkpoint: path to the trained model checkpoint (required).

    The checkpoint is indexed with the keys 'encoder', 'decoder' and 'vocab';
    the first two are passed to ``load_state_dict`` and the last is stored in
    ``resnet_train.vocab``. The value returned by ``visualize_attention`` is
    printed to stdout.

    Raises:
        SystemExit: raised by argparse when a required argument is missing.
        KeyError: if the loaded checkpoint is a dict lacking one of the
            expected keys.
    """
    parser = argparse.ArgumentParser(description="Generate image caption from a trained model.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained model checkpoint")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code. It is kept because the
    # checkpoint also carries a 'vocab' object (presumably a Vocabulary
    # instance — confirm). Only load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)

    # Publish the checkpoint's vocabulary to resnet_train *before* building
    # the models, so that any code in resnet_train that reads the module-level
    # `vocab` (possibly the model constructors — not visible from here) sees
    # the vocabulary the weights were trained with.
    resnet_train.vocab = checkpoint['vocab']

    # Build the models on the target device and restore the trained weights.
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Switch to evaluation mode so layers such as dropout / batch norm (if
    # the models contain any) use their inference behaviour.
    encoder.eval()
    decoder.eval()

    # Generate the caption for the provided image path and print it.
    caption = visualize_attention(args.image, encoder, decoder, device)
    print(caption)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()