raun12345678 committed on
Commit
7ad3688
·
verified ·
1 Parent(s): 8dee779

Streamlit_app

Browse files
Files changed (1) hide show
  1. app.py +301 -0
app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import PIL
4
+ import cv2
5
+
6
+
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch import optim
13
+ from torch.utils.data import DataLoader
14
+ import torch.nn.functional as F
15
+ from torch.distributions import Categorical
16
+
17
+ import torchvision
18
+ import torchvision.datasets as datasets
19
+ import torchvision.transforms as transforms
20
+
21
+ from transformers import AutoTokenizer
22
# Select CUDA device 0 when a GPU is available, otherwise fall back to CPU.
device = torch.device(0 if torch.cuda.is_available() else 'cpu')
23
+
24
def extract_patches(image_tensor, patch_size=16):
    """Split a batch of images into flattened, non-overlapping patches.

    Args:
        image_tensor: Tensor of shape (batch, channels, height, width).
        patch_size: Side length of each square patch.

    Returns:
        Tensor of shape (batch, num_patches, channels * patch_size**2),
        one row per patch in row-major scan order.
    """
    batch, channels, _, _ = image_tensor.size()

    # A patch-size window with a patch-size stride tiles the image
    # without overlap.
    unfolded = F.unfold(image_tensor, kernel_size=patch_size, stride=patch_size)

    # (batch, C*P*P, L) -> (batch, L, C*P*P): one flattened patch per row.
    return unfolded.transpose(1, 2).reshape(batch, -1,
                                            channels * patch_size * patch_size)
39
+
40
+ # sinusoidal positional embeds
41
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Classic sinusoidal positional embedding (first half sin, second half cos)."""

    def __init__(self, dim):
        super().__init__()
        # Output embedding width.
        self.dim = dim

    def forward(self, x):
        """Map a 1-D tensor of positions to a (len(x), dim) embedding matrix."""
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = torch.exp(
            torch.arange(half, device=x.device) * -(math.log(10000) / (half - 1))
        )
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
54
+
55
+
56
+ # Define a module for attention blocks
57
# Define a module for attention blocks
class AttentionBlock(nn.Module):
    """Multi-head attention with optional causal (upper-triangular) masking."""

    def __init__(self, hidden_size=128, num_heads=4, masking=True):
        super(AttentionBlock, self).__init__()
        # When True, position i may only attend to positions <= i.
        self.masking = masking

        self.multihead_attn = nn.MultiheadAttention(hidden_size,
                                                    num_heads=num_heads,
                                                    batch_first=True,
                                                    dropout=0.0)

    def forward(self, x_in, kv_in, key_mask=None):
        """Attend from x_in (queries) over kv_in (keys and values).

        key_mask marks padded key positions to be ignored; the attention
        weights returned by MultiheadAttention are discarded.
        """
        causal_mask = None
        if self.masking:
            seq_len = x_in.shape[1]
            # True strictly above the diagonal -> future positions are hidden.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x_in.device), 1
            ).bool()

        attn_out, _ = self.multihead_attn(x_in, kv_in, kv_in,
                                          attn_mask=causal_mask,
                                          key_padding_mask=key_mask)
        return attn_out
79
+
80
+
81
+ # Define a module for a transformer block with self-attention
82
+ # and optional causal masking
83
# Define a module for a transformer block with self-attention
# and optional causal masking
class TransformerBlock(nn.Module):
    """One transformer layer: masked self-attention, optional cross-attention
    (decoder mode), then a feed-forward MLP.

    Normalization is applied *after* each residual add (post-norm layout).
    """

    def __init__(self, hidden_size=128, num_heads=4, decoder=False, masking=True):
        super(TransformerBlock, self).__init__()
        self.decoder = decoder

        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn1 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads,
                                    masking=masking)

        if self.decoder:
            # Cross-attention over the encoder output is never causally masked.
            self.norm2 = nn.LayerNorm(hidden_size)
            self.attn2 = AttentionBlock(hidden_size=hidden_size,
                                        num_heads=num_heads, masking=False)

        self.norm_mlp = nn.LayerNorm(hidden_size)
        # Standard 4x-expansion feed-forward network.
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size * 4),
                                 nn.ELU(),
                                 nn.Linear(hidden_size * 4, hidden_size))

    def forward(self, x, input_key_mask=None, cross_key_mask=None, kv_cross=None):
        """Run the layer; kv_cross is the encoder output (decoder mode only)."""
        # Self-attention with residual connection, then post-norm.
        x = self.norm1(x + self.attn1(x, x, key_mask=input_key_mask))

        # Decoder layers additionally cross-attend to the encoder output.
        if self.decoder:
            x = self.norm2(x + self.attn2(x, kv_cross, key_mask=cross_key_mask))

        # Feed-forward with residual connection, then post-norm.
        return self.norm_mlp(x + self.mlp(x))
121
+
122
+
123
+ # Define a decoder module for the Transformer architecture
124
# Define a decoder module for the Transformer architecture
class Decoder(nn.Module):
    """Token decoder: embeds the target sequence, adds sinusoidal positional
    embeddings, applies causally-masked self-attention plus cross-attention
    blocks, and projects to vocabulary logits.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Shrink the initial embedding scale for more stable early training.
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Stack of decoder-mode transformer blocks (self + cross attention).
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads,
                             decoder=True) for _ in range(num_layers)
        ])

        # Projects final hidden states to per-token vocabulary logits.
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        """Return logits of shape (batch, seq_len, num_emb)."""
        token_embs = self.embedding(input_seq)
        batch, seq_len, width = token_embs.shape

        # Positional embeddings, broadcast over the batch dimension.
        positions = torch.arange(seq_len, device=input_seq.device)
        pos_embs = self.pos_emb(positions).reshape(1, seq_len, width)
        hidden = token_embs + pos_embs.expand(batch, seq_len, width)

        # Run through every transformer block, cross-attending to the
        # encoder output each time.
        for layer in self.blocks:
            hidden = layer(hidden,
                           input_key_mask=input_padding_mask,
                           cross_key_mask=encoder_padding_mask,
                           kv_cross=encoder_output)

        return self.fc_out(hidden)
164
+
165
+
166
+ # Define an Vision Encoder module for the Transformer architecture
167
# Define an Vision Encoder module for the Transformer architecture
class VisionEncoder(nn.Module):
    """ViT-style encoder: patchify, linearly project, add a learned positional
    embedding, then run a stack of unmasked transformer blocks."""

    def __init__(self, image_size, channels_in, patch_size=16, hidden_size=128,
                 num_layers=3, num_heads=4):
        super(VisionEncoder, self).__init__()

        self.patch_size = patch_size
        # Projects each flattened patch to the model width.
        self.fc_in = nn.Linear(channels_in * patch_size * patch_size, hidden_size)

        # One learned positional vector per patch location.
        seq_length = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length,
                                                      hidden_size).normal_(std=0.02))

        # Encoder blocks: no decoder cross-attention, no causal masking.
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads,
                             decoder=False, masking=False) for _ in range(num_layers)
        ])

    def forward(self, image):
        """Encode a (batch, C, H, W) image batch into (batch, L, hidden) tokens."""
        patches = extract_patches(image, patch_size=self.patch_size)
        tokens = self.fc_in(patches)

        # Learned positional embedding, broadcast across the batch.
        tokens = tokens + self.pos_embedding

        for layer in self.blocks:
            tokens = layer(tokens)

        return tokens
199
+
200
+
201
+ # Define an Vision Encoder-Decoder module for the Transformer architecture
202
# Define an Vision Encoder-Decoder module for the Transformer architecture
class VisionEncoderDecoder(nn.Module):
    """Image-captioning model: a VisionEncoder over the image feeding a
    cross-attending token Decoder."""

    def __init__(self, image_size, channels_in, num_emb, patch_size=16,
                 hidden_size=128, num_layers=(3, 3), num_heads=4):
        super(VisionEncoderDecoder, self).__init__()

        # num_layers[0] encoder blocks, num_layers[1] decoder blocks.
        self.encoder = VisionEncoder(image_size=image_size, channels_in=channels_in,
                                     patch_size=patch_size, hidden_size=hidden_size,
                                     num_layers=num_layers[0], num_heads=num_heads)

        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=num_layers[1], num_heads=num_heads)

    def forward(self, input_image, target_seq, padding_mask):
        """Return vocabulary logits for every target-sequence position."""
        # MultiheadAttention expects True where a key position should be
        # ignored; the incoming mask marks padding with 0.
        key_padding = padding_mask == 0

        image_tokens = self.encoder(image=input_image)

        return self.decoder(input_seq=target_seq,
                            encoder_output=image_tokens,
                            input_padding_mask=key_padding)
227
+
228
# SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
# checkpoint -- only load files from a trusted source.
# map_location=device lets a GPU-trained checkpoint load on CPU-only hosts,
# which otherwise raises at deserialization time.
model = torch.load("caption_model.pth", map_location=device, weights_only=False)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
231
+
232
def pred_transformer_caption(test_img, max_tokens=50, temp=0.5):
    """Autoregressively sample a caption for a preprocessed image tensor.

    Args:
        test_img: Image tensor shaped for ``model.encoder`` -- assumed
            (1, 3, H, W); TODO confirm against the trained checkpoint.
        max_tokens: Hard cap on the number of generated tokens.
        temp: Softmax temperature for sampling (lower = greedier).

    Returns:
        The decoded caption string with special tokens stripped.
    """
    # BERT-style special tokens used by the distilbert tokenizer:
    # 101 = [CLS] (start-of-sequence), 102 = [SEP] (end-of-sequence).
    sos_token = 101 * torch.ones(1, 1).long()

    log_tokens = [sos_token]
    model.eval()

    with torch.no_grad():
        # Mixed precision only applies (and only works) on CUDA; the old
        # unconditional torch.cuda.amp.autocast() warned on CPU-only hosts.
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            # Encode the input image once; reused for every decode step.
            image_embedding = model.encoder(test_img.to(device))

        # Generate the caption tokens one at a time.
        for _ in range(max_tokens):
            input_tokens = torch.cat(log_tokens, 1)

            # Decode the prefix into next-token logits.
            data_pred = model.decoder(input_tokens.to(device), image_embedding)

            # Temperature-scaled sampling from the last position's logits.
            dist = Categorical(logits=data_pred[:, -1] / temp)
            next_tokens = dist.sample().reshape(1, 1)

            # Append the sampled token to the running sequence.
            log_tokens.append(next_tokens.cpu())

            # Stop as soon as the end-of-sequence token is produced.
            if next_tokens.item() == 102:
                break

    # Convert the list of token tensors into one (1, seq_len) tensor.
    pred_text = torch.cat(log_tokens, 1)

    # decode() already returns a single string; no manual join is needed.
    return tokenizer.decode(pred_text[0], skip_special_tokens=True)
283
+
284
## Dashboard

st.title("Caption_APP")
test_img = st.file_uploader(label="upload the funny pic :) :",
                            type=["png", "jpg", "jpeg"])
caption = ""
if test_img:
    # Force 3 channels so RGBA/grayscale uploads don't break the encoder.
    image = PIL.Image.open(test_img).convert("RGB")
    image = image.resize((128, 128))

    arr = np.array(image).astype("float32")
    # Min-max normalize to [0, 1]; guard against a constant-colour image
    # (the unguarded division produced NaNs).
    span = np.amax(arr) - np.amin(arr)
    arr = (arr - np.amin(arr)) / (span if span > 0 else 1.0)

    # Keep the HWC array for display before converting for the model.
    display_copy = arr

    # HWC -> CHW plus a batch dimension: (1, 3, 128, 128). The previous
    # reshape + unsqueeze produced a channels-last 5-D tensor, which crashes
    # the encoder's (bs, c, h, w) patch extraction.
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)

    caption = str(pred_transformer_caption(tensor))
    st.image(image=display_copy, caption=caption)