"""Caption-generation inference wrapper.

Loads a trained image-captioning model plus its vocabulary from disk and
exposes a simple ``generate(image)`` API supporting greedy and beam-search
decoding. Inference is CPU-only by default (see ``DEVICE``).
"""

import os
import pickle
import sys

import torch
import torchvision.transforms as T
from PIL import Image

# NOTE: vocab.pkl may have been saved when Vocabulary lived at module path
# 'data_processing_vocabulary'. If unpickling fails with a ModuleNotFoundError,
# re-enable an alias such as:
#   from src import data_processing_vocabulary
#   sys.modules["data_processing_vocabulary"] = data_processing_vocabulary
from image_captioning_model import ImageCaptioningModel

# Inference device; change to "cuda" if a GPU is available.
DEVICE = "cpu"


class CaptionGenerator:
    """Load the model and vocabulary once, then caption arbitrary images."""

    def __init__(
        self,
        model_path="best_model.pth",
        vocab_path="vocab.pkl",
        use_vit=True,
    ):
        """Build the model and load trained weights.

        Args:
            model_path: Path to the saved state dict (.pth).
            vocab_path: Path to the pickled Vocabulary object.
            use_vit: Whether the model uses a ViT image encoder.
        """
        # --- Vocabulary -------------------------------------------------
        # SECURITY: pickle.load executes arbitrary code during
        # deserialization — only load vocab files from trusted sources.
        print("Loading vocab...")
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        print(f"Vocab loaded with {len(self.vocab)} words.")

        # --- Model ------------------------------------------------------
        print("Building model...")
        # NOTE(review): word2idx[""] (empty string) looks like a rendering
        # artifact of "<pad>" — angle brackets may have been stripped from
        # the original source. Confirm against the Vocabulary class before
        # changing; the key is preserved as-is here.
        self.model = ImageCaptioningModel(
            vocab_size=len(self.vocab),
            pad_id=self.vocab.word2idx[""],
            use_vit=use_vit,
        )
        print("Model built.")

        # --- Weights ----------------------------------------------------
        print("Loading state dict...")
        # weights_only=True restricts unpickling to tensors/containers —
        # safe for a plain state dict and avoids arbitrary-code execution
        # from a tampered checkpoint (requires torch >= 1.13).
        state_dict = torch.load(
            model_path, map_location=DEVICE, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        print(f"Model weights loaded from {model_path}.")

        self.model.eval()
        self.model.to(DEVICE)

        # Standard ImageNet preprocessing, matching pretrained backbones.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def preprocess(self, image):
        """Convert a PIL Image or file path into a batched model input.

        Args:
            image: PIL Image, or a path accepted by PIL.Image.open.

        Returns:
            A float tensor of shape (1, 3, 224, 224).
        """
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("RGB")
        return self.transform(image).unsqueeze(0)

    def generate(self, image, max_len=30, decoding="beam", beam_width=5):
        """Generate a caption for the given image.

        Args:
            image: PIL Image or file path.
            max_len: Maximum caption length.
            decoding: 'greedy' or 'beam'.
            beam_width: Number of beams (only used when decoding='beam').

        Returns:
            Generated caption string.
        """
        image = self.preprocess(image).to(DEVICE)
        if decoding == "beam":
            return self.model.predict_caption_beam(
                image=image,
                vocab=self.vocab,
                beam_width=beam_width,
                max_len=max_len,
                device=DEVICE,
            )
        # Any other value falls back to greedy decoding, as in the
        # original implementation.
        return self.model.predict_caption(
            image=image,
            vocab=self.vocab,
            max_len=max_len,
            device=DEVICE,
        )