import os
import pickle
import re
from collections import Counter
from io import BytesIO

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from matplotlib import pyplot as plt
from PIL import Image
from torch import nn
from torchvision import models, transforms

# Model hyper-parameters; they must agree with the checkpoint loaded below.
EMBED_DIM = 256
HIDDEN_DIM = 512
MAX_SEQ_LENGTH = 25
VOCAB_SIZE = 8492
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference preprocessing: resize to 384x384, convert to a tensor in [0, 1],
# then map every channel to [-1, 1] via (x - 0.5) / 0.5.
transform_inference = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
    ),
])

# Decoder outputs that are never shown in the rendered caption.
_SKIP_WORDS = frozenset({"startofseq", "pad", "endofseq"})


class Vocabulary:
    """Word <-> index mapping with four reserved special tokens (ids 0-3).

    The instance used at inference time is unpickled in
    ``load_model_and_vocab``; pickle resolves this class by name, so its name
    and attribute layout must stay compatible with the pickled object.
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next free id for a regular word

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lower-case ``text`` and return its runs of word characters."""
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def build_vocabulary(self, sentence_list):
        """Add every token seen at least ``freq_threshold`` times in
        ``sentence_list`` to the vocabulary, assigning consecutive ids."""
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of token ids.

        Out-of-vocabulary words map to the id of the "unk" token.
        """
        unk_idx = self.stoi["unk"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer(text)]


class ViTEncoder(nn.Module):
    """Frozen ViT-B/16 backbone projected to ``embed_dim`` and batch-normed."""

    def __init__(self, embed_dim):
        super().__init__()
        # The SWAG weights are loaded here so that torchvision builds the ViT
        # with the matching configuration; the actual parameter values are
        # then overwritten by the checkpoint in load_model_and_vocab.
        weights = models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
        vit = models.vit_b_16(weights=weights)
        # Drop the classification head so the model returns the pooled
        # hidden representation instead of class logits.
        self.vit = vit
        self.vit.heads = nn.Identity()
        # Freeze the backbone: only the projection and batch norm below are
        # trainable. Set requires_grad to True to fine-tune the ViT instead.
        for param in self.vit.parameters():
            param.requires_grad = False
        # Projection to the decoder's embedding size.
        self.fc = nn.Linear(self.vit.hidden_dim, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images (B, 3, H, W) into (B, embed_dim) features."""
        features = self.vit(images)  # (B, vit.hidden_dim)
        features = self.fc(features)  # (B, embed_dim)
        features = self.batch_norm(features)
        return features


class DecoderLSTM(nn.Module):
    """LSTM language model conditioned on an image feature fed as step 0."""

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, features, captions, states):
        """Run the LSTM over ``[features, embed(captions)]``.

        Args:
            features: (B, embed_dim) image features, prepended as position 0.
            captions: (B, T) token ids.
            states: LSTM ``(h, c)`` tuple, or None for a zero initial state.

        Returns:
            ``(logits, states)`` with logits of shape (B, T + 1, vocab_size).
        """
        embeddings = self.embedding(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        lstm_out, states = self.lstm(inputs, states)
        logits = self.fc(lstm_out)
        return logits, states

    def generate(self, features, max_len=20):
        """Greedily decode a caption for a single image.

        Args:
            features: encoder output; the start-token tensor below has batch
                size 1, so this expects ``features`` of shape (1, embed_dim).
            max_len: maximum number of token ids to emit.

        Returns:
            list[int] of predicted ids. Decoding stops right after the
            end-of-sequence id (2) is produced, so when present it is the
            last element.
        """
        start_idx = 1  # startofseq
        end_idx = 2  # endofseq
        generated_captions = []
        if max_len <= 0:
            return generated_captions

        # Prime the LSTM on [image feature, start token] from a zero state,
        # the same layout `forward` uses.
        start_tokens = torch.tensor(
            [[start_idx]], dtype=torch.long, device=features.device
        )
        logits, states = self.forward(features, start_tokens, None)

        while True:
            predicted = logits[:, -1, :].argmax(dim=-1)  # (1,)
            token = predicted.item()
            generated_captions.append(token)
            if token == end_idx or len(generated_captions) >= max_len:
                break
            # Feed only the newly predicted token: `states` already encodes
            # every earlier position. Re-feeding the whole prefix on top of
            # the carried state would run the LSTM over it twice.
            embedded = self.embedding(predicted.unsqueeze(1))  # (1, 1, embed_dim)
            lstm_out, states = self.lstm(embedded, states)
            logits = self.fc(lstm_out)
        return generated_captions


class ImageCaptioningModel(nn.Module):
    """Encoder/decoder wrapper exposing a single ``generate`` entry point."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=MAX_SEQ_LENGTH):
        """Encode ``images`` and greedily decode token ids for them."""
        features = self.encoder(images)
        return self.decoder.generate(features, max_len=max_len)


def load_model_and_vocab(repo_id):
    """Download ``repo_id`` from the Hugging Face Hub and load model + vocab.

    Returns:
        ``(model, vocab)``: the model on ``DEVICE`` in eval mode, and the
        unpickled ``Vocabulary``.
    """
    download_dir = snapshot_download(repo_id)
    print(download_dir)
    model_path = os.path.join(download_dir, "best_finetuned_infer.pth")
    vocab_path = os.path.join(download_dir, "vocab.pkl")

    encoder = ViTEncoder(embed_dim=EMBED_DIM)
    decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, VOCAB_SIZE)
    model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict['model_state_dict'])
    model.eval()

    # SECURITY NOTE(review): pickle.load executes arbitrary code from the
    # file; only load vocab.pkl from a repository you trust.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    return model, vocab


model, vocab = load_model_and_vocab("prakhartrivedi/ImageCaptioningSCCCI")
print("Model and vocabulary loaded successfully.")


def _indices_to_caption(indices):
    """Join decoder ids into a caption string.

    Stops at the first "endofseq" id, drops the special tokens in
    ``_SKIP_WORDS``, and renders ids missing from the vocabulary as "unk".
    """
    end_token_idx = vocab.stoi["endofseq"]
    words = []
    for idx in indices:
        if idx == end_token_idx:
            break
        word = vocab.itos.get(idx, "unk")
        if word not in _SKIP_WORDS:
            words.append(word)
    return " ".join(words)


def generate_caption_for_image(img):
    """Caption ``img`` and return a rendered PIL image titled with the caption.

    Returns None when no image was supplied (Gradio passes None when the user
    submits an empty input).
    """
    if img is None:
        return None

    pil_img = img.convert("RGB")
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)
    cap = _indices_to_caption(output_indices)

    # (1, 3, H, W) -> (H, W, 3), then undo the [-1, 1] normalization.
    image_np = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * 0.5 + 0.5).clip(0, 1)

    # Render with an explicit figure and always close it, even on error,
    # so figures do not accumulate in a long-running server.
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(image_np)
        ax.axis("off")
        ax.set_title(cap)
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


gr.Interface(
    fn=generate_caption_for_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
).launch(share=True)