File size: 6,520 Bytes
2713ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
e1a9146
2713ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1a9146
2713ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os, torch, pickle, re
from io import BytesIO
from torchvision import models, transforms
from matplotlib import pyplot as plt
from torch import nn
from collections import Counter
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download

# Model/decoding hyperparameters — must match the trained checkpoint.
EMBED_DIM = 256
HIDDEN_DIM = 512
MAX_SEQ_LENGTH = 25  # maximum number of tokens generated per caption
VOCAB_SIZE = 8492  # vocabulary size the checkpoint was trained with
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference preprocessing: 384x384 matches the SWAG E2E ViT-B/16 weights
# used by the encoder; Normalize(mean=0.5, std=0.5) maps pixels to [-1, 1]
# (undone again with x * 0.5 + 0.5 when plotting the result).
transform_inference = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

class Vocabulary:
    """Maps caption words to integer ids and back.

    Special tokens: 0 = "pad", 1 = "startofseq", 2 = "endofseq", 3 = "unk".
    Words seen fewer than ``freq_threshold`` times in the training
    sentences are excluded and map to the "unk" id.
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # NOTE: special tokens are stored WITHOUT angle brackets; every
        # lookup elsewhere must use the same bare spellings.
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next id assigned by build_vocabulary

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lowercase *text* and split it into word tokens."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Assign ids to every word appearing at least ``freq_threshold``
        times across *sentence_list*, in first-seen order."""
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Convert *text* to a list of token ids, mapping out-of-vocabulary
        words to the "unk" id.

        BUG FIX: the unknown token is stored as "unk" (no angle brackets),
        but this method previously looked up ``self.stoi["<unk>"]``, which
        raised KeyError for every out-of-vocabulary word.
        """
        unk_idx = self.stoi["unk"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer(text)]

class ViTEncoder(nn.Module):
    """Frozen ViT-B/16 backbone mapping an image batch to a fixed-size
    embedding for the caption decoder.

    Only the projection layer and batch norm are trainable; the ViT
    parameters are frozen.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # SWAG end-to-end fine-tuned weights; their expected inference
        # resolution is 384x384 (matches transform_inference).
        backbone = models.vit_b_16(
            weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
        )
        backbone.heads = nn.Identity()  # drop the ImageNet classifier head
        self.vit = backbone

        # Freeze the backbone — only self.fc / self.batch_norm learn.
        for p in self.vit.parameters():
            p.requires_grad = False

        # Project the ViT feature vector down to the decoder embedding size.
        self.fc = nn.Linear(self.vit.hidden_dim, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        """images: (B, 3, H, W) -> (B, embed_dim)."""
        backbone_features = self.vit(images)   # (B, vit.hidden_dim)
        projected = self.fc(backbone_features)  # (B, embed_dim)
        return self.batch_norm(projected)

class DecoderLSTM(nn.Module):
    """LSTM caption decoder.

    Training (``forward``): the image feature vector is prepended to the
    embedded caption tokens and the whole sequence runs through the LSTM.

    Inference (``generate``): greedy decoding for one image — the LSTM
    state is primed once with the image features, then exactly one new
    token is fed per step.
    """

    START_IDX = 1  # "startofseq" in Vocabulary
    END_IDX = 2    # "endofseq" in Vocabulary

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, features, captions, states):
        """Teacher-forced pass.

        features: (B, embed_dim); captions: (B, T) token ids.
        Returns logits of shape (B, T + 1, vocab_size) and the LSTM states.
        """
        embeddings = self.embedding(captions)
        # Image features act as the first "pseudo token" of the sequence.
        inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        lstm_out, states = self.lstm(inputs, states)
        logits = self.fc(lstm_out)
        return logits, states

    def generate(self, features, max_len=20):
        """Greedy-decode a caption for a single image.

        features: (1, embed_dim). Returns a list of token ids; decoding
        stops early when END_IDX is produced (the caller strips it).

        BUG FIX: the previous version re-fed the entire growing prefix —
        with the image features re-prepended by ``forward`` on every call —
        while ALSO carrying the LSTM state forward, so the features and
        every earlier token were processed multiple times per step. It
        also sized logits with the module global VOCAB_SIZE instead of
        ``self.vocab_size``. Each input is now consumed exactly once.
        """
        # Prime the recurrent state with the image features.
        _, states = self.lstm(features.unsqueeze(1), None)

        token = torch.full(
            (1, 1), self.START_IDX, dtype=torch.long, device=features.device
        )
        generated = []
        for _ in range(max_len):
            out, states = self.lstm(self.embedding(token), states)
            logits = self.fc(out[:, -1, :])  # (1, vocab_size)
            next_id = int(logits.argmax(dim=1).item())
            generated.append(next_id)
            if next_id == self.END_IDX:
                break  # end-of-sequence reached
            token = token.new_full((1, 1), next_id)
        return generated

class ImageCaptioningModel(nn.Module):
    """Thin wrapper tying the image encoder to the caption decoder."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=MAX_SEQ_LENGTH):
        """Encode *images* and greedily decode a caption (list of ids)."""
        return self.decoder.generate(self.encoder(images), max_len=max_len)

def load_model_and_vocab(repo_id):
    """Download the checkpoint and vocabulary from the Hugging Face Hub
    repo *repo_id*; return (model in eval mode on DEVICE, vocab)."""
    download_dir = snapshot_download(repo_id)
    print(download_dir)

    model = ImageCaptioningModel(
        ViTEncoder(embed_dim=EMBED_DIM),
        DecoderLSTM(EMBED_DIM, HIDDEN_DIM, VOCAB_SIZE),
    ).to(DEVICE)

    checkpoint = torch.load(
        os.path.join(download_dir, "best_finetuned_infer.pth"),
        map_location=DEVICE,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # NOTE(security): pickle.load executes arbitrary code from the
    # downloaded file — only load vocab files from a trusted repo.
    with open(os.path.join(download_dir, "vocab.pkl"), 'rb') as f:
        vocab = pickle.load(f)

    return model, vocab

# Load the checkpoint and vocabulary once at import time so the Gradio
# handler below reuses them for every request.
model, vocab = load_model_and_vocab("prakhartrivedi/ImageCaptioningSCCCI")
print("Model and vocabulary loaded successfully.")

def generate_caption_for_image(img):
    """Gradio handler: caption *img* and return a rendered PIL image —
    the model's preprocessed view of the image with the caption as title.

    img: a PIL image (any mode); converted to RGB before preprocessing.
    """
    pil_img = img.convert("RGB")
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)

    # Map ids back to words, stopping at the end-of-sequence token and
    # dropping any remaining special tokens.
    result_words = []
    end_token_idx = vocab.stoi["endofseq"]
    for idx in output_indices:
        if idx == end_token_idx:
            break
        word = vocab.itos.get(idx, "unk")
        if word not in ["startofseq", "pad", "endofseq"]:
            result_words.append(word)
    cap = " ".join(result_words)

    # Convert tensor (1, 3, H, W) to (H, W, 3) for plotting.
    image_np = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * 0.5 + 0.5).clip(0, 1)  # undo Normalize(0.5, 0.5)

    # BUG FIX: a bare plt.close() closes whichever figure is "current"
    # (fragile once this runs inside a server), and the figure leaked
    # into pyplot's global registry whenever savefig raised. Close this
    # specific figure in a finally block instead.
    fig = plt.figure(figsize=(5, 5))
    try:
        plt.imshow(image_np)
        plt.axis("off")
        plt.title(cap)

        # Render the plot into an in-memory PNG buffer.
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)

    # PIL reads lazily from the buffer; the Image object keeps it alive.
    return Image.open(buf)

# Minimal Gradio UI: upload an image, get back the captioned rendering.
# share=True also exposes a temporary public URL via Gradio's tunnel.
gr.Interface(
    fn=generate_caption_for_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil")
).launch(share=True)