LeahLv committed on
Commit
7ac94ec
·
1 Parent(s): 9613024

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Databricks notebook source
2
+
3
+ import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from PIL import Image
7
+ #from turtle import width
8
+ import streamlit as st
9
+
10
+
11
+ # COMMAND ----------
12
+
13
def load_image_initial(image_file):
    """Open an uploaded file as a PIL image for on-screen preview (no preprocessing)."""
    return Image.open(image_file)
16
+
17
+
18
# --- Streamlit page layout -------------------------------------------------
# Three vertical containers: title/intro, image upload + preview, and the
# caption area (filled in at the bottom of the file).
header = st.container()
image = st.container()
caption = st.container()

with header:
    st.title('Image Captioning')
    st.text('Generate captions for your images!')

with image:
    # st.markdown("**upload your image here:**")
    image_file = st.file_uploader("upload your image here:", type = ["png", "jpg", 'jpeg'])
    # file_uploader returns None until the user picks a file.
    if image_file is not None:
        #st.write(type(image_file))
        # st.write(dir(image_file))
        # file_details = {"filename": image_file.name, "filetype":image_file.type, "filesize":image_file.size}
        # st.write(file_details)
        # width=299 matches the input resolution fed to the feature extractor below.
        st.image(load_image_initial(image_file), width=299)
36
+
37
+
38
+
39
+
40
################################ model 14 configuration
# Number of captions sampled per uploaded image (decoding is stochastic, so
# each sample can differ).
num_predictions = 3
# Backbone used for feature extraction; must match the one used at training time.
feature_extraction_model = 'ResNet50'
# Pickled TextVectorization config + weights produced during training.
tokenizer_path = 'tokenizer.pkl'
# checkpoint_path = "/dbfs/FileStore/shared_uploads/mhajiza@gap.com/computer_vision/models/image_captioning_tf_14/ckpt-10"
# checkpoint_path = "/dbfs/FileStore/shared_uploads/mhajiza@gap.com/computer_vision/models/image_captioning_tf_14/manually_saved_model-11"
# checkpoint_path = "/Users/mhajiza/Documents/Computer_Vison/Image_captioning/image_captioning_tf_model/ckpt-10"
# Encoder/decoder checkpoint prefix, relative to the app's working directory.
checkpoint_path = "ckpt-10"
# NOTE(review): passed as the `weights=` argument of tf.keras.applications.*;
# anything other than 'imagenet'/None is treated as a local weights file path,
# so a file named "checkpoint" must exist here -- confirm.
weights= "checkpoint"

# checkpoint_path = "/Users/mhajiza/Documents/Computer_Vison/Image_captioning/image_captioning_tf_model/manually_saved_model-11"
51
+
52
+ # COMMAND ----------
53
+
54
def load_image(image_file, model_name=None):
    """Load and preprocess an image for the feature-extraction backbone.

    Args:
        image_file: Path or file-like object accepted by PIL.Image.open.
        model_name: Backbone identifier ('InceptionV3', 'ResNet50', 'ResNet101'
            or 'ResNet152'). Defaults to the module-level
            feature_extraction_model, so existing load_image(f) calls behave
            exactly as before; the parameter removes the hard-coded global
            dependence for other callers.

    Returns:
        Tuple of (preprocessed 299x299x3 float tensor, the original image_file).
    """
    if model_name is None:
        model_name = feature_extraction_model
    # Force RGB so grayscale / RGBA uploads still yield 3 channels.
    img = Image.open(image_file).convert('RGB')
    img = tf.keras.preprocessing.image.img_to_array(img)
    img = tf.keras.layers.Resizing(299, 299)(img)
    if model_name == 'InceptionV3':
        img = tf.keras.applications.inception_v3.preprocess_input(img)
    elif model_name in ('ResNet50', 'ResNet101', 'ResNet152'):
        # All ResNet variants share the same preprocessing function.
        img = tf.keras.applications.resnet.preprocess_input(img)
    return img, image_file
63
+
64
+ # COMMAND ----------
65
+
66
+
67
# Initialize the feature-extraction backbone and load its pretrained weights.
# The original if-chain repeated the same four lines per backbone; all four
# supported models are wired identically, so dispatch through a dict instead:
# take the model without its classification head and expose the last layer's
# output as the spatial feature map.
_backbone_constructors = {
    'ResNet50': tf.keras.applications.ResNet50,
    'ResNet101': tf.keras.applications.ResNet101,
    'ResNet152': tf.keras.applications.ResNet152,
    'InceptionV3': tf.keras.applications.InceptionV3,
}
# A KeyError here means feature_extraction_model names an unsupported backbone
# (the old if-chain would instead leave these names silently undefined).
image_model = _backbone_constructors[feature_extraction_model](include_top=False, weights=weights)
new_input = image_model.input
hidden_layer = image_model.layers[-1].output
image_features_extract_model = tf.keras.Model(new_input, hidden_layer)
88
+
89
+ # COMMAND ----------
90
+
91
+
92
def standardize(inputs):
    """Lowercase caption text and strip punctuation for the TextVectorization layer.

    Bug fix: the original pattern lacked the surrounding character class
    ``[...]``, so regex_replace matched only the entire 30-odd-character
    punctuation run as one literal string and effectively never removed
    anything. NOTE(review): the pickled tokenizer below restores its saved
    vocabulary weights verbatim (adapt() is only a dummy call), so inference
    lookups are unaffected, but confirm against the training-time standardize.
    """
    inputs = tf.strings.lower(inputs)
    return tf.strings.regex_replace(inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]", "")
95
import pickle
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

# Rebuild the training-time tokenizer: restore its configuration, run a dummy
# adapt() so the layer builds its internal state, then overwrite that state
# with the saved vocabulary weights. The order of these three calls matters.
from_disk = pickle.load(open(tokenizer_path, "rb"))
tokenizer = TextVectorization.from_config(from_disk['config'])
# adapt() on a throwaway sentence only initialises the layer; the real
# vocabulary comes from set_weights() on the next line.
tokenizer.adapt(["this is a test"])
tokenizer.set_weights(from_disk['weights'])
101
+
102
+ # COMMAND ----------
103
+
104
# Read the sequence limits straight from the restored tokenizer configuration.
_tokenizer_config = tokenizer.get_config()
vocabulary_size = _tokenizer_config['max_tokens']
max_length = _tokenizer_config['output_sequence_length']
106
+
107
+ # COMMAND ----------
108
+
109
# Bidirectional word <-> index mappings built over the restored vocabulary.
_vocab = tokenizer.get_vocabulary()
word_to_index = tf.keras.layers.StringLookup(mask_token="", vocabulary=_vocab)
index_to_word = tf.keras.layers.StringLookup(mask_token="", vocabulary=_vocab, invert=True)
112
+
113
+ # COMMAND ----------
114
+
115
# max_length = 95 ##100
# Decoder embedding and GRU sizes; must match the values used at training time
# or the checkpoint restore below will not line up.
embedding_dim = 256
units = 512
# Shape of the vector extracted from InceptionV3 is (64, 2048)
# These two variables represent that vector shape
# NOTE(review): with ResNet50 on 299x299 input the spatial grid appears to be
# 10x10 = 100 positions (evaluate() below allocates rows of width 100), so
# these two constants look InceptionV3-specific and unused here -- confirm.
features_shape = 2048
attention_features_shape = 64
122
+
123
+ # COMMAND ----------
124
+
125
class BahdanauAttention(tf.keras.Model):
    """Additive (Bahdanau-style) attention over the encoder's spatial features."""

    def __init__(self, units):
        super().__init__()
        # Two projections into a shared `units`-dim space plus a scoring head.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Return (context_vector, attention_weights) for one decode step.

        features: encoder output, (batch, num_positions, embedding_dim).
        hidden:   previous decoder GRU state, (batch, hidden_size).
        """
        # Broadcast the decoder state across every spatial position.
        query_with_time_axis = tf.expand_dims(hidden, 1)

        # (batch, num_positions, units)
        scored_hidden = tf.nn.tanh(self.W1(features) + self.W2(query_with_time_axis))

        # Unnormalised score per position, softmaxed over the position axis
        # -> (batch, num_positions, 1).
        attention_weights = tf.nn.softmax(self.V(scored_hidden), axis=1)

        # Attention-weighted sum of the features: (batch, embedding_dim).
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)

        return context_vector, attention_weights
155
+
156
+ # COMMAND ----------
157
+
158
class CNN_Encoder(tf.keras.Model):
    """Projects precomputed backbone features through a single dense + ReLU layer.

    The convolutional features are already extracted elsewhere; this "encoder"
    only maps each spatial position's channel vector into the embedding space.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # (batch, positions, channels) -> (batch, positions, embedding_dim)
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        return tf.nn.relu(self.fc(x))
170
+
171
+ # COMMAND ----------
172
+
173
class RNN_Decoder(tf.keras.Model):
    """GRU decoder with Bahdanau attention; emits logits for one token per call."""

    def __init__(self, embedding_dim, units, vocab_size):
        super().__init__()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            self.units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )
        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.units)

    def call(self, x, features, hidden):
        """One decode step given the previous token ids, image features and state."""
        # Attend over the image features using the previous decoder state.
        context_vector, attention_weights = self.attention(features, hidden)

        # (batch, 1) token ids -> (batch, 1, embedding_dim)
        embedded = self.embedding(x)

        # Prepend the attention context along the feature axis:
        # (batch, 1, embedding_dim + context_dim).
        gru_input = tf.concat([tf.expand_dims(context_vector, 1), embedded], axis=-1)

        output, state = self.gru(gru_input)

        # (batch, 1, units) -> flatten the time axis -> vocabulary logits.
        logits = self.fc1(output)
        logits = tf.reshape(logits, (-1, logits.shape[2]))
        logits = self.fc2(logits)

        return logits, state, attention_weights

    def reset_state(self, batch_size):
        """Zeroed GRU state used to start decoding a fresh image."""
        return tf.zeros((batch_size, self.units))
214
+
215
+ # COMMAND ----------
216
+
217
# Build the captioning model; the sizes must match the restored checkpoint.
encoder = CNN_Encoder(embedding_dim)
decoder = RNN_Decoder(embedding_dim, units, tokenizer.vocabulary_size())
219
+
220
+ # COMMAND ----------
221
+
222
# Optimizer and per-token loss. No training happens in this app; these exist
# so the tf.train.Checkpoint below matches the object graph saved at training
# time (reduction='none' keeps one loss value per token for masking).
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
225
+
226
+
227
def loss_function(real, pred):
    """Masked sparse cross-entropy: padding positions (token id 0) contribute nothing."""
    per_token_loss = loss_object(real, pred)
    # 1.0 for real tokens, 0.0 where `real` holds the padding id 0.
    pad_mask = tf.cast(tf.math.not_equal(real, 0), dtype=per_token_loss.dtype)
    return tf.reduce_mean(per_token_loss * pad_mask)
235
+
236
+ # COMMAND ----------
237
+
238
# Tie encoder/decoder/optimizer into one checkpoint object so the saved
# training state can be restored into them.
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)
# ckpt.restore(ckpt_manager.latest_checkpoint)

# COMMAND ----------

# NOTE(review): checkpoint_path ("ckpt-10") is used both as the manager's
# directory above and as a checkpoint file prefix here -- confirm which layout
# the deployed files actually follow. tf restore is lazy: a mismatch only
# surfaces when the variables are first used by evaluate().
ckpt.restore(checkpoint_path)
247
+
248
+ # COMMAND ----------
249
+
250
def evaluate(image):
    """Sample one caption for `image`; return (token_list, attention_plot).

    Decoding samples from the predicted distribution (tf.random.categorical),
    so repeated calls on the same image can produce different captions.
    attention_plot holds one row of attention weights per generated token
    (width 100 -- presumably the 10x10 spatial grid of the ResNet50 feature
    map on a 299x299 input; TODO confirm).
    """
    # attention_plot = np.zeros((max_length, attention_features_shape))
    attention_plot = np.zeros((max_length, 100))

    hidden = decoder.reset_state(batch_size=1)

    # Run the backbone once and flatten the spatial grid into a sequence of
    # feature vectors: (1, H, W, C) -> (1, H*W, C).
    preprocessed = tf.expand_dims(load_image(image)[0], 0)
    feature_map = image_features_extract_model(preprocessed)
    feature_map = tf.reshape(feature_map, (feature_map.shape[0],
                                           -1,
                                           feature_map.shape[3]))
    features = encoder(feature_map)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for step in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input,
                                                         features,
                                                         hidden)

        attention_plot[step] = tf.reshape(attention_weights, (-1, )).numpy()

        # Sample the next token id from the predicted distribution and map it
        # back to its word string.
        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        # NOTE: on early stop the attention matrix is returned untrimmed,
        # matching the original behaviour; only the max-length path trims it.
        if predicted_word == '<end>':
            return result, attention_plot

        # Feed the sampled token back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, attention_plot[:len(result), :]
287
+
288
+ # COMMAND ----------
289
+
290
def plot_attention(image, result, attention_plot):
    """Overlay each generated word's attention map on the source image.

    Args:
        image: Path or file-like object for the original photo.
        result: List of generated caption tokens.
        attention_plot: 2-D array with one row of attention weights per token.

    Bug fix: the heat-map grid side is now derived from the attention row
    length instead of being hard-coded to 8x8. evaluate() above produces rows
    of 100 weights, so np.resize(..., (8, 8)) silently dropped the last 36
    values; 100 values now render as a 10x10 map.
    """
    temp_image = np.array(Image.open(image))

    fig = plt.figure(figsize=(30, 30))

    len_result = len(result)
    for i in range(len_result):
        # One square map per token; np.resize pads/repeats only when the row
        # length is not a perfect square.
        side = max(int(np.sqrt(attention_plot[i].size)), 1)
        temp_att = np.resize(attention_plot[i], (side, side))
        grid_size = max(int(np.ceil(len_result / 2)), 2)
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.set_title(result[i])
        img = ax.imshow(temp_image)
        # Semi-transparent attention heat map drawn over the photo.
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())

    plt.tight_layout()
    plt.show()
306
+
307
+ # COMMAND ----------
308
+
309
+
310
# Once an image is uploaded, generate and display several candidate captions.
if image_file is not None:
    with caption:
        st.header("generated captions by model:")
        for i in range(1, num_predictions+1):
            p = st.empty()
            # evaluate() samples stochastically, so each iteration can yield a
            # different caption for the same image.
            result, _ = evaluate(image_file)
            # NOTE(review): result may end with the literal '<end>' token,
            # which is shown as part of the caption text -- confirm intended.
            pred = ' '.join(result)
            p.write(f"**caption {i}**: {pred}")
            # st.header("**caption**")
            # st.text(pred)
320
+
321
+
322
+
323
+