Spaces:
Sleeping
Sleeping
import os
import pickle
import re
from collections import Counter
from io import BytesIO

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import nn
from torchvision import models, transforms
# Decoder hyper-parameters.
# NOTE(review): presumably these must match the checkpoint loaded further down - confirm.
EMBED_DIM = 256
HIDDEN_DIM = 512
MAX_SEQ_LENGTH = 25
VOCAB_SIZE = 8492

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Inference-time preprocessing: resize to 384x384, convert to a tensor, then
# normalize every channel with mean 0.5 and std 0.5.
transform_inference = transforms.Compose(
    [
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``pad``, ``startofseq``,
    ``endofseq`` and ``unk``. Words admitted by ``build_vocabulary`` are
    numbered from 4 upward in the order they are first seen.
    """

    def __init__(self, freq_threshold=5):
        """Create a vocabulary that only contains the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences a word needs in
                order to be added by ``build_vocabulary``.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next free id

    def __len__(self):
        """Return the number of ids currently in the vocabulary."""
        return len(self.itos)

    def tokenizer(self, text):
        """Lower-case *text* and return its runs of word characters."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Add every word that occurs at least ``freq_threshold`` times.

        Args:
            sentence_list: Iterable of raw sentences (strings).
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))

        # NOTE(review): a corpus word equal to a reserved token (e.g. "pad")
        # that passes the threshold overwrites that token's ``stoi`` entry
        # here - confirm whether this is intended before retraining.
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Convert *text* to a list of token ids.

        Words that are not in the vocabulary map to the id of ``unk``.
        """
        # The unknown token is registered as "unk" (no angle brackets), so it
        # must be looked up under exactly that key.
        unk_id = self.stoi["unk"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]
class ViTEncoder(nn.Module):
    """Image encoder: a frozen ViT-B/16 backbone, a linear projection and batch norm."""

    def __init__(self, embed_dim):
        """Build the encoder.

        Args:
            embed_dim: Size of the feature vector produced for the decoder.
        """
        super().__init__()

        # Pretrained ViT-B/16 with its classification head replaced by an
        # identity, so the backbone returns its hidden representation.
        backbone = models.vit_b_16(
            weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
        )
        self.vit = backbone
        self.vit.heads = nn.Identity()

        # Freeze the backbone; only the layers defined below stay trainable.
        for weight in self.vit.parameters():
            weight.requires_grad = False

        # Project the backbone output down to the decoder's embedding size.
        self.fc = nn.Linear(self.vit.hidden_dim, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        """Encode images of shape (B, 3, H, W) into features of shape (B, embed_dim)."""
        hidden = self.vit(images)  # (B, vit.hidden_dim)
        projected = self.fc(hidden)  # (B, embed_dim)
        return self.batch_norm(projected)
class DecoderLSTM(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    the embedded caption tokens.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        """Build the decoder.

        Args:
            embed_dim: Size of the token embeddings and of the image feature.
            hidden_dim: Hidden size of the LSTM.
            vocab_size: Number of token ids the decoder can emit.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, features, captions, states):
        """Run the LSTM over ``[features, embedded captions]``.

        Args:
            features: Image features of shape (B, embed_dim).
            captions: Token ids of shape (B, T).
            states: Initial LSTM state, or ``None`` for a zero state.

        Returns:
            ``(logits, states)`` where ``logits`` has shape
            (B, T + 1, vocab_size) and ``states`` is the final LSTM state.
        """
        embeddings = self.embedding(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        lstm_out, states = self.lstm(inputs, states)
        logits = self.fc(lstm_out)
        return logits, states

    def generate(self, features, max_len=20):
        """Greedily decode a caption for a single image.

        Args:
            features: Image features of shape (1, embed_dim).
            max_len: Maximum number of tokens to generate.

        Returns:
            A list of generated token ids. Decoding stops after the end
            token (id 2) is produced, which is included as the last element,
            or after ``max_len`` tokens.

        Raises:
            ValueError: If ``features`` does not hold exactly one image.
        """
        if features.size(0) != 1:
            raise ValueError(
                f"generate() decodes one image at a time, got a batch of {features.size(0)}"
            )

        start_idx = 1  # startofseq
        end_idx = 2  # endofseq

        # Prime the LSTM with [image feature, start token]. The prefix is fed
        # exactly once: afterwards only the newest token is fed together with
        # the carried state. Re-feeding the whole prefix on top of an already
        # advanced state would make the LSTM consume it repeatedly.
        start_token = torch.tensor(
            [[start_idx]], dtype=torch.long, device=features.device
        )
        step_inputs = torch.cat(
            (features.unsqueeze(1), self.embedding(start_token)), dim=1
        )

        states = None
        generated_captions = []
        for _ in range(max_len):
            lstm_out, states = self.lstm(step_inputs, states)
            # Only the last time step predicts the next token.
            predicted = self.fc(lstm_out[:, -1]).argmax(dim=-1)  # shape (1,)
            token = predicted.item()
            generated_captions.append(token)
            if token == end_idx:
                break
            step_inputs = self.embedding(predicted).unsqueeze(1)  # (1, 1, embed_dim)
        return generated_captions
class ImageCaptioningModel(nn.Module):
    """Couples an image encoder with a caption decoder."""

    def __init__(self, encoder, decoder):
        """Store the two sub-modules.

        Args:
            encoder: Module that maps images to feature vectors.
            decoder: Module exposing ``generate(features, max_len=...)``.
        """
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def generate(self, images, max_len=MAX_SEQ_LENGTH):
        """Encode *images* and return whatever the decoder generates from the features."""
        image_features = self.encoder(images)
        token_ids = self.decoder.generate(image_features, max_len=max_len)
        return token_ids
def load_model_and_vocab(repo_id):
    """Download a Hub snapshot and rebuild the captioning model and vocabulary.

    Args:
        repo_id: Hugging Face Hub repository id holding
            ``best_finetuned_infer.pth`` and ``vocab.pkl``.

    Returns:
        ``(model, vocab)``: the model, moved to ``DEVICE`` and switched to
        eval mode, and the unpickled vocabulary object.
    """
    snapshot_dir = snapshot_download(repo_id)
    print(snapshot_dir)

    checkpoint_file = os.path.join(snapshot_dir, "best_finetuned_infer.pth")
    vocab_file = os.path.join(snapshot_dir, "vocab.pkl")

    # Rebuild the architecture, then restore the weights from the checkpoint.
    captioner = ImageCaptioningModel(
        ViTEncoder(embed_dim=EMBED_DIM),
        DecoderLSTM(EMBED_DIM, HIDDEN_DIM, VOCAB_SIZE),
    ).to(DEVICE)
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    captioner.load_state_dict(checkpoint['model_state_dict'])
    captioner.eval()

    # NOTE(review): pickle.load can execute arbitrary code from the file, so
    # only point this at repositories you trust.
    with open(vocab_file, 'rb') as handle:
        vocab = pickle.load(handle)

    return captioner, vocab
# Load the model and vocabulary once at import time; the request handler
# below reads these module-level objects.
model, vocab = load_model_and_vocab(
    repo_id="prakhartrivedi/ImageCaptioningSCCCI",
)
print("Model and vocabulary loaded successfully.")
def _decode_caption(token_ids):
    """Join generated token ids into a caption, stopping at the end token."""
    end_token_idx = vocab.stoi["endofseq"]
    special_words = ("startofseq", "pad", "endofseq")

    words = []
    for idx in token_ids:
        if idx == end_token_idx:
            break
        word = vocab.itos.get(idx, "unk")
        if word not in special_words:
            words.append(word)
    return " ".join(words)


def _render_captioned_image(image_np, caption):
    """Draw *image_np* with *caption* as its title and return it as a PIL image."""
    # A standalone Figure avoids pyplot's global figure registry: nothing has
    # to be closed afterwards and nothing is left registered if rendering fails.
    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    ax.imshow(image_np)
    ax.axis("off")
    ax.set_title(caption)

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def generate_caption_for_image(img):
    """Caption an uploaded image and return it rendered with the caption as title.

    Args:
        img: PIL image supplied by the Gradio input component, or ``None``
            when no image was provided.

    Returns:
        A PIL image showing the preprocessed input with the caption above it.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    pil_img = img.convert("RGB")
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)
    cap = _decode_caption(output_indices)

    # Convert the tensor from (1, 3, H, W) to (H, W, 3) and undo the
    # mean=0.5 / std=0.5 normalization so the image displays correctly.
    image_np = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * 0.5 + 0.5).clip(0, 1)

    return _render_captioned_image(image_np, cap)
# Build the web UI (PIL image in, captioned PIL image out) and start serving it.
demo = gr.Interface(
    fn=generate_caption_for_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)
demo.launch(share=True)