"""Inference-side image captioning: ResNet-50 encoder + LSTM decoder.

Importing this module downloads the checkpoint from the Hugging Face Hub,
reads ``vocab.pkl`` from the directory containing this file, and builds the
model once. Use :func:`caption_image` as the public entry point.
"""

import os
import pickle
import re
from collections import Counter

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EMBED_DIM = 512
HIDDEN_DIM = 512
MAX_LEN = 25


# -----------------------
# Vocabulary
# -----------------------
class Vocabulary:
    """Bidirectional word <-> index mapping with four reserved tokens.

    Indices 0-3 are reserved for "pad", "startofseq", "endofseq" and "unk";
    corpus words are numbered from 4 upwards.
    """

    def __init__(self, freq_threshold=5):
        """Create a vocabulary holding only the reserved tokens.

        Args:
            freq_threshold: Minimum number of occurrences a word needs in
                order to be added by :meth:`build_vocabulary`.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4

    def __len__(self):
        """Return the number of index -> token entries."""
        return len(self.itos)

    def tokenizer(self, text):
        """Lower-case *text* and return its ``\\w+`` runs as a list."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of *sentence_list*.

        Args:
            sentence_list: Iterable of raw caption strings.

        Words already present in ``stoi`` are skipped. Without that guard a
        corpus word spelled like a reserved token (e.g. "pad" or "unk")
        would re-point the reserved token to a new index, and calling this
        method twice would leave stale duplicates in ``itos``.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))

        for word, freq in frequencies.items():
            if freq < self.freq_threshold or word in self.stoi:
                continue
            self.stoi[word] = self.index
            self.itos[self.index] = word
            self.index += 1

    def numericalize(self, text):
        """Return the token indices of *text*; unknown words map to "unk"."""
        unk_idx = self.stoi["unk"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer(text)]


# -----------------------
# Encoder
# -----------------------
class ResNetEncoder(nn.Module):
    """ResNet-50 backbone followed by a linear projection and batch norm."""

    def __init__(self, embed_dim):
        """Build the encoder.

        Args:
            embed_dim: Size of the image feature vector produced.
        """
        super().__init__()
        # No pretrained weights are requested here; this module loads its
        # parameters from the checkpoint further down.
        resnet = models.resnet50(weights=None)
        # Drop the last child (the classification ``fc`` layer).
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images into ``(batch, embed_dim)`` features."""
        # The backbone runs without gradient tracking, so only ``fc`` and
        # ``batch_norm`` can receive gradients.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.fc(features)
        features = self.batch_norm(features)
        return features


# -----------------------
# Decoder
# -----------------------
class DecoderLSTM(nn.Module):
    """LSTM language model conditioned on an image feature vector."""

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        """Build the decoder.

        Args:
            embed_dim: Token embedding size (also the LSTM input size).
            hidden_dim: LSTM hidden state size.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        """Return vocabulary logits for every position of *captions*.

        Args:
            features: Image features; a time axis is inserted at dim 1.
            captions: Batch-first tensor of token indices.

        The last caption token is dropped and the image feature is prepended
        as the first time step, so the output has as many time steps as
        *captions* has tokens.
        """
        captions = captions[:, :-1]
        emb = self.embedding(captions)
        features = features.unsqueeze(1)
        lstm_input = torch.cat((features, emb), dim=1)
        outputs, _ = self.lstm(lstm_input)
        logits = self.fc(outputs)
        return logits


# -----------------------
# Caption Model
# -----------------------
class ImageCaptioningModel(nn.Module):
    """Pair an image encoder with a caption decoder."""

    def __init__(self, encoder, decoder):
        """Store the *encoder* and *decoder* sub-modules."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        """Encode *images* and decode them against *captions*."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs


# -----------------------
# Caption Generator
# -----------------------
def generate_caption(model, image, vocab, max_len=MAX_LEN):
    """Greedily decode a caption for a single image tensor.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` as used below.
        image: Un-batched image tensor; a batch axis is added here.
        vocab: Object exposing ``stoi`` and ``itos`` mappings.
        max_len: Maximum number of decoding steps (defaults to ``MAX_LEN``).

    Returns:
        The predicted tokens joined by single spaces. Decoding stops at the
        first "endofseq" token, which is not included.
    """
    model.eval()
    image = image.unsqueeze(0).to(DEVICE)
    sentence = []

    with torch.no_grad():
        features = model.encoder(image)
        word_idx = vocab.stoi["startofseq"]
        hidden = None

        for _ in range(max_len):
            word_tensor = torch.tensor([word_idx], device=DEVICE)
            emb = model.decoder.embedding(word_tensor)

            if hidden is None:
                # First step: feed the image feature followed by the start
                # token, the same order DecoderLSTM.forward builds.
                lstm_input = torch.cat(
                    [features.unsqueeze(1), emb.unsqueeze(1)], dim=1
                )
            else:
                lstm_input = emb.unsqueeze(1)

            output, hidden = model.decoder.lstm(lstm_input, hidden)
            # Only the last time step predicts the next token.
            logits = model.decoder.fc(output[:, -1, :])
            predicted = logits.argmax(1).item()

            token = vocab.itos[predicted]
            if token == "endofseq":
                break

            sentence.append(token)
            word_idx = predicted

    return " ".join(sentence)


# -----------------------
# Image Transform
# -----------------------
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# -----------------------
# Load Model Once
# -----------------------
script_dir = os.path.dirname(os.path.abspath(__file__))

CHECKPOINT_PATH = hf_hub_download(
    repo_id="VIKRAM989/image-label", filename="best_checkpoint.pth"
)
VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves any class named "Vocabulary" to this module's.

    The lookup ignores the module recorded in the pickle, so the file loads
    regardless of which module the class lived in when it was saved.
    """

    def find_class(self, module, name):
        """Return this module's ``Vocabulary`` for that name, else defer."""
        if name == "Vocabulary":
            return Vocabulary
        return super().find_class(module, name)


# SECURITY NOTE: unpickling can execute arbitrary code. VOCAB_PATH must only
# ever point at a trusted file shipped with this module.
with open(VOCAB_PATH, "rb") as f:
    vocab = CustomUnpickler(f).load()

vocab_size = len(vocab)

encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()


# -----------------------
# Public Function for API
# -----------------------
def caption_image(pil_image):
    """Return a generated caption for a PIL image.

    Args:
        pil_image: A ``PIL.Image.Image`` in any mode.

    Returns:
        The caption string produced by :func:`generate_caption`.
    """
    # The normalisation above is defined for exactly three channels, so
    # grayscale, palette, RGBA or CMYK inputs must be converted first.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    img = transform(pil_image).to(DEVICE)
    return generate_caption(model, img, vocab)