# vit-image-captioning / inference.py
# (Hugging Face upload by mostafahagali, commit 82907d9 — page chrome removed
# so this file parses as Python.)
import torch
import pickle
from PIL import Image
import torchvision.transforms as T
import os
import sys
# # Ensure project root is on sys.path
# PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__)))
# if PROJECT_ROOT not in sys.path:
# sys.path.insert(0, PROJECT_ROOT)
# # Backward compatibility: vocab.pkl was saved when Vocabulary lived at
# # 'data_processing_vocabulary'. Register an alias so pickle can find it
# # at its new location 'src.data_processing_vocabulary'.
# from src import data_processing_vocabulary
# sys.modules["data_processing_vocabulary"] = data_processing_vocabulary
from image_captioning_model import ImageCaptioningModel
# Default inference device. CPU keeps the demo portable (no GPU required);
# tensors and the model are explicitly moved here before decoding.
DEVICE = "cpu"
class CaptionGenerator:
    """Load a trained image-captioning model and caption images with it.

    Wraps vocabulary loading, model construction, checkpoint loading, and
    the image preprocessing pipeline behind a single ``generate()`` call.
    """

    def __init__(
        self,
        model_path="best_model.pth",
        vocab_path="vocab.pkl",
        use_vit=True,
        device=DEVICE,
    ):
        """Build the captioning model and load its trained weights.

        Args:
            model_path: Path to the saved ``state_dict`` checkpoint.
            vocab_path: Path to the pickled vocabulary object (must expose
                ``__len__`` and ``word2idx``).
            use_vit: Whether the model uses the ViT image encoder.
            device: Torch device string to run inference on (default: the
                module-level ``DEVICE``, i.e. ``"cpu"``).
        """
        self.device = device

        # Load vocab.
        # NOTE(review): pickle.load executes arbitrary code during
        # deserialization — only load vocab files from a trusted source.
        print("Loading vocab...")
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        print(f"Vocab loaded with {len(self.vocab)} words.")

        # Build model. pad_id tells the model which token index to ignore.
        print("Building model...")
        self.model = ImageCaptioningModel(
            vocab_size=len(self.vocab),
            pad_id=self.vocab.word2idx["<pad>"],
            use_vit=use_vit,
        )
        print("Model built.")

        # Load weights onto the target device.
        print("Loading state dict...")
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        print(f"Model weights loaded from {model_path}.")

        # Inference mode: disable dropout/batch-norm updates.
        self.model.eval()
        self.model.to(self.device)

        # Preprocessing: 224x224 with ImageNet normalization — the input
        # contract of the (ViT) image encoder.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def preprocess(self, image):
        """Convert a PIL image or an image file path to a (1, 3, 224, 224) tensor."""
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("RGB")
        image = self.transform(image)
        return image.unsqueeze(0)

    def generate(self, image, max_len=30, decoding="beam", beam_width=5):
        """
        Generate a caption for the given image.

        Args:
            image: PIL Image or file path.
            max_len: Maximum caption length.
            decoding: 'greedy' or 'beam' (anything other than 'beam' falls
                back to greedy decoding).
            beam_width: Number of beams (only used when decoding='beam').

        Returns:
            Generated caption string.
        """
        image = self.preprocess(image).to(self.device)
        if decoding == "beam":
            return self.model.predict_caption_beam(
                image=image,
                vocab=self.vocab,
                beam_width=beam_width,
                max_len=max_len,
                device=self.device,
            )
        return self.model.predict_caption(
            image=image,
            vocab=self.vocab,
            max_len=max_len,
            device=self.device,
        )