Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from PIL import Image | |
| import pickle | |
| import os | |
| import re | |
| from collections import Counter | |
| from huggingface_hub import hf_hub_download | |
# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512   # size of the image-feature / word-embedding vectors
HIDDEN_DIM = 512  # LSTM hidden-state size
MAX_LEN = 25      # maximum number of greedy decoding steps per caption
| # ----------------------- | |
| # Vocabulary | |
| # ----------------------- | |
class Vocabulary:
    """Word-level vocabulary with reserved special tokens and a frequency cutoff.

    Ids 0-3 are reserved for "pad", "startofseq", "endofseq" and "unk";
    corpus words are numbered from 4 upward in first-seen order.
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next id to hand out

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lowercase *text* and split it into word tokens."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Assign ids to every word occurring at least ``freq_threshold`` times."""
        counts = Counter()
        for sentence in sentence_list:
            counts.update(self.tokenizer(sentence))
        for word, count in counts.items():
            if count >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Convert *text* into a list of token ids; unknown words map to "unk"."""
        unk_id = self.stoi["unk"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]
| # ----------------------- | |
| # Encoder | |
| # ----------------------- | |
class ResNetEncoder(nn.Module):
    """ResNet-50 backbone (run without gradients) projected to the embed size."""

    def __init__(self, embed_dim):
        super().__init__()
        backbone = models.resnet50(weights=None)
        # Drop the final classification layer; keep conv stages + global pool.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.fc = nn.Linear(backbone.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        # Backbone features are computed under no_grad; only the projection
        # head (fc + batch_norm) participates in autograd.
        with torch.no_grad():
            pooled = self.resnet(images)
        flat = pooled.view(pooled.size(0), -1)
        return self.batch_norm(self.fc(flat))
| # ----------------------- | |
| # Decoder | |
| # ----------------------- | |
class DecoderLSTM(nn.Module):
    """LSTM caption decoder: image features form timestep 0, words follow."""

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        # Teacher forcing: drop the final token, then prepend the image
        # feature vector as the first element of the input sequence.
        token_emb = self.embedding(captions[:, :-1])
        sequence = torch.cat((features.unsqueeze(1), token_emb), dim=1)
        hidden_states, _ = self.lstm(sequence)
        return self.fc(hidden_states)
| # ----------------------- | |
| # Caption Model | |
| # ----------------------- | |
class ImageCaptioningModel(nn.Module):
    """Composes an image encoder and a caption decoder into one module."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        # Encode the images, then let the decoder condition on the features.
        return self.decoder(self.encoder(images), captions)
| # ----------------------- | |
| # Caption Generator | |
| # ----------------------- | |
def generate_caption(model, image, vocab):
    """Greedily decode a caption for one preprocessed image tensor.

    Args:
        model: ImageCaptioningModel exposing ``.encoder`` and a ``.decoder``
            with ``embedding`` / ``lstm`` / ``fc`` submodules.
        image: single image tensor (no batch dim; one is added here).
            NOTE(review): assumed to be the normalized 3x224x224 output of
            ``transform`` — confirm against callers.
        vocab: Vocabulary providing ``stoi`` / ``itos`` mappings.

    Returns:
        The decoded caption as a space-joined string; the "startofseq" and
        "endofseq" special tokens are never included.
    """
    model.eval()
    image = image.unsqueeze(0).to(DEVICE)
    sentence = []
    with torch.no_grad():
        features = model.encoder(image)
        word_idx = vocab.stoi["startofseq"]
        hidden = None
        for _ in range(MAX_LEN):
            word_tensor = torch.tensor([word_idx]).to(DEVICE)
            emb = model.decoder.embedding(word_tensor)
            if hidden is None:
                # First step mirrors training: image features are timestep 0,
                # immediately followed by the start-token embedding.
                lstm_input = torch.cat(
                    [features.unsqueeze(1), emb.unsqueeze(1)], dim=1
                )
            else:
                # Subsequent steps feed only the previous word; the carried
                # LSTM hidden state retains the image context.
                lstm_input = emb.unsqueeze(1)
            output, hidden = model.decoder.lstm(lstm_input, hidden)
            # Score only the last timestep and pick the argmax (greedy decode).
            logits = model.decoder.fc(output[:, -1, :])
            predicted = logits.argmax(1).item()
            token = vocab.itos[predicted]
            if token == "endofseq":
                break
            sentence.append(token)
            word_idx = predicted
    return " ".join(sentence)
| # ----------------------- | |
| # Image Transform | |
| # ----------------------- | |
# Preprocessing pipeline: resize to the ResNet input resolution and
# normalize with the standard ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
| # ----------------------- | |
| # Load Model Once | |
| # ----------------------- | |
# Resolve artifact locations once at import time.
script_dir = os.path.dirname(os.path.abspath(__file__))
# Fetch the trained checkpoint from the Hugging Face Hub (cached locally).
CHECKPOINT_PATH = hf_hub_download(
    repo_id="VIKRAM989/image-label",
    filename="best_checkpoint.pth"
)
# The vocabulary pickle is expected to sit next to this script.
VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that remaps the pickled ``Vocabulary`` class to this module's.

    The vocab pickle may reference ``Vocabulary`` under the module path it
    was saved from; every other class resolves normally.
    """

    def find_class(self, module, name):
        if name == "Vocabulary":
            return Vocabulary
        return super().find_class(module, name)
# Rebuild the vocabulary and the model, then restore trained weights.
# NOTE(review): both vocab.pkl and the checkpoint are deserialized via
# pickle, which can execute arbitrary code — only load trusted artifacts.
with open(VOCAB_PATH, "rb") as f:
    vocab = CustomUnpickler(f).load()
vocab_size = len(vocab)
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
# NOTE(review): consider torch.load(..., weights_only=True) if the
# checkpoint is a plain tensor state dict — confirm its contents first.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout
| # ----------------------- | |
| # Public Function for API | |
| # ----------------------- | |
def caption_image(pil_image):
    """Preprocess a PIL image and return the model's generated caption string."""
    tensor = transform(pil_image).to(DEVICE)
    return generate_caption(model, tensor, vocab)