| import torch
|
| import pickle
|
| from PIL import Image
|
| import torchvision.transforms as T
|
| import os
|
| import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from image_captioning_model import ImageCaptioningModel
|
|
|
|
|
# Inference device. Hardcoded to CPU — presumably intentional for a
# deployment/demo environment without a GPU; switch to "cuda" manually
# if one is available (TODO confirm with the training setup).
DEVICE = "cpu"
|
|
|
|
|
class CaptionGenerator:
    """Inference wrapper around a trained ``ImageCaptioningModel``.

    Loads a pickled vocabulary and a saved ``state_dict``, builds the model
    on ``DEVICE`` in eval mode, and exposes :meth:`generate` to caption a
    single image via greedy or beam-search decoding.
    """

    def __init__(
        self,
        model_path="best_model.pth",
        vocab_path="vocab.pkl",
        use_vit=True
    ):
        """Build the model and load weights/vocabulary from disk.

        Args:
            model_path: Path to the saved model ``state_dict``.
            vocab_path: Path to the pickled vocabulary object (must expose
                ``__len__`` and ``word2idx``).
            use_vit: Forwarded to ``ImageCaptioningModel``; selects the
                image encoder backbone.
        """
        # NOTE(security): pickle (and torch.load, which uses pickle) can
        # execute arbitrary code during deserialization — only load vocab
        # and checkpoint files from trusted sources.
        print("Loading vocab...")
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        print(f"Vocab loaded with {len(self.vocab)} words.")

        print("Building model...")
        self.model = ImageCaptioningModel(
            vocab_size=len(self.vocab),
            pad_id=self.vocab.word2idx["<pad>"],
            use_vit=use_vit
        )
        print("Model built.")

        print("Loading state dict...")
        state_dict = torch.load(model_path, map_location=DEVICE)
        self.model.load_state_dict(state_dict)
        print(f"Model weights loaded from {model_path}.")

        # Inference-only: disable dropout/batchnorm updates and move to DEVICE.
        self.model.eval()
        self.model.to(DEVICE)

        # 224x224 resize + ImageNet mean/std normalization — the standard
        # input pipeline for ImageNet-pretrained backbones.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def preprocess(self, image):
        """Convert an image (PIL Image or file path) to a model-ready batch.

        Args:
            image: A ``PIL.Image.Image``, or anything ``PIL.Image.open``
                accepts (e.g. a file path).

        Returns:
            A float tensor of shape ``(1, 3, 224, 224)`` (batch of one).
        """
        if not isinstance(image, Image.Image):
            # Force RGB so grayscale/RGBA inputs produce 3 channels.
            image = Image.open(image).convert("RGB")

        image = self.transform(image)
        return image.unsqueeze(0)

    def generate(self, image, max_len=30, decoding="beam", beam_width=5):
        """
        Generate a caption for the given image.

        Args:
            image: PIL Image or file path.
            max_len: Maximum caption length.
            decoding: 'greedy' or 'beam'.
            beam_width: Number of beams (only used when decoding='beam').

        Returns:
            Generated caption string.
        """
        image = self.preprocess(image).to(DEVICE)

        if decoding == "beam":
            caption = self.model.predict_caption_beam(
                image=image,
                vocab=self.vocab,
                beam_width=beam_width,
                max_len=max_len,
                device=DEVICE
            )
        else:
            # Any value other than "beam" falls through to greedy decoding.
            caption = self.model.predict_caption(
                image=image,
                vocab=self.vocab,
                max_len=max_len,
                device=DEVICE
            )

        return caption
|
|
|