UtkarshShivhare committed on
Commit
ec85771
·
1 Parent(s): e3aee9e

Upload Inage Captioning.ipynb

Files changed (1)
  1. Inage Captioning.ipynb +972 -0
Inage Captioning.ipynb ADDED
@@ -0,0 +1,972 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "113985e3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\utkar\\anaconda4\\lib\\site-packages\\scipy\\__init__.py:138: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5)\n",
+ "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion} is required for this version of \"\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pickle\n",
+ "import re\n",
+ "import os\n",
+ "import pandas\n",
+ "import numpy as np\n",
+ "from tqdm.notebook import tqdm\n",
+ "from tensorflow.keras.applications.vgg16 import VGG16,preprocess_input\n",
+ "from tensorflow.keras.preprocessing.image import load_img,img_to_array\n",
+ "from tensorflow.keras.preprocessing.text import Tokenizer\n",
+ "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
+ "from tensorflow.keras.models import Model\n",
+ "from tensorflow.keras.utils import to_categorical,plot_model\n",
+ "from tensorflow.keras.layers import Input,Dense,LSTM,Embedding, Dropout, add"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "6f9ba09d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# raw strings, so backslashes in the Windows paths are not read as escape sequences\n",
+ "work=r\"C:\\crawlers\\Project_hastag\\save\"\n",
+ "base=r\"C:\\crawlers\\Project_hastag\\Archive\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "204bf9d6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: \"model\"\n",
+ "_________________________________________________________________\n",
+ " Layer (type)                Output Shape              Param #   \n",
+ "=================================================================\n",
+ " input_1 (InputLayer)        [(None, 224, 224, 3)]     0         \n",
+ "                                                                 \n",
+ " block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      \n",
+ "                                                                 \n",
+ " block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     \n",
+ "                                                                 \n",
+ " block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         \n",
+ "                                                                 \n",
+ " block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     \n",
+ "                                                                 \n",
+ " block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    \n",
+ "                                                                 \n",
+ " block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         \n",
+ "                                                                 \n",
+ " block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    \n",
+ "                                                                 \n",
+ " block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    \n",
+ "                                                                 \n",
+ " block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    \n",
+ "                                                                 \n",
+ " block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         \n",
+ "                                                                 \n",
+ " block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   \n",
+ "                                                                 \n",
+ " block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   \n",
+ "                                                                 \n",
+ " block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   \n",
+ "                                                                 \n",
+ " block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         \n",
+ "                                                                 \n",
+ " block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   \n",
+ "                                                                 \n",
+ " block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   \n",
+ "                                                                 \n",
+ " block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   \n",
+ "                                                                 \n",
+ " block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         \n",
+ "                                                                 \n",
+ " flatten (Flatten)           (None, 25088)             0         \n",
+ "                                                                 \n",
+ " fc1 (Dense)                 (None, 4096)              102764544 \n",
+ "                                                                 \n",
+ " fc2 (Dense)                 (None, 4096)              16781312  \n",
+ "                                                                 \n",
+ "=================================================================\n",
+ "Total params: 134,260,544\n",
+ "Trainable params: 134,260,544\n",
+ "Non-trainable params: 0\n",
+ "_________________________________________________________________\n"
+ ]
+ }
+ ],
+ "source": [
+ "model=VGG16()\n",
+ "model=Model(inputs=model.inputs,outputs=model.layers[-2].output)\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "22708632",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "908848215aa6423a84b9e8398a2da55b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "  0%|          | 0/8091 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Feature extraction\n",
+ "fs={}\n",
+ "directory=os.path.join(base,'Images')\n",
+ "\n",
+ "for img in tqdm(os.listdir(directory)):\n",
+ "    img_name=os.path.join(directory,img)\n",
+ "    image=load_img(img_name,target_size=(224,224))\n",
+ "    image=img_to_array(image)\n",
+ "    image=image.reshape(1,image.shape[0],image.shape[1],image.shape[2])\n",
+ "    image=preprocess_input(image)\n",
+ "    f=model.predict(image,verbose=0)\n",
+ "    im_id=img.split(\".\")[0]\n",
+ "    fs[im_id]=f\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "195b8d10",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pickle.dump(fs,open(os.path.join(work,\"features.pkl\"),\"wb\"))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "7c0d0727",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(os.path.join(work,\"features.pkl\"),\"rb\") as f:\n",
+ "    features=pickle.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "f4425b39",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(os.path.join(base, 'captions.txt'), 'r') as f:\n",
+ "    next(f) # skip the header line\n",
+ "    caption_data = f.read()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "9eeb446f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "eccf97da2e1744378f9e03922dc4bc7b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "  0%|          | 0/40456 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "ma={}\n",
+ "data=caption_data.split(\"\\n\")\n",
+ "for line in tqdm(data):\n",
+ "    mapp=line.split(\",\")\n",
+ "    if len(mapp)<2:\n",
+ "        continue\n",
+ "    im_id=mapp[0]\n",
+ "    cap=\",\".join(mapp[1:]) # rejoin, in case the caption itself contains commas\n",
+ "    im_id=im_id.split(\".\")[0]\n",
+ "    if im_id not in ma:\n",
+ "        ma[im_id]=[]\n",
+ "    ma[im_id].append(cap)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d4621146",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def process_text(cap):\n",
+ "    cap=cap.lower()\n",
+ "    cap=re.sub(r\"[^a-z ]\",\"\",cap) # str.replace would treat the pattern literally, so use a regex\n",
+ "    cap=re.sub(r\"\\s+\",\" \",cap)\n",
+ "    cap=\"startseq \"+\" \".join([word for word in cap.split(\" \") if len(word)>1])+\" endseq\"\n",
+ "    return cap\n",
+ "\n",
+ "def clean(ma):\n",
+ "    for key, cap in ma.items():\n",
+ "        for i in range(len(cap)):\n",
+ "            cap[i]=process_text(cap[i])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "249084e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "clean(ma)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "82d3499a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_captions = []\n",
+ "for key in ma:\n",
+ "    for caption in ma[key]:\n",
+ "        all_captions.append(caption)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "cd90e55f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "40455"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(all_captions)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "b195a348",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = Tokenizer()\n",
+ "tokenizer.fit_on_texts(all_captions)\n",
+ "vocab_size = len(tokenizer.word_index) + 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "06788c74",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "35"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "max_length = max(len(caption.split()) for caption in all_captions)\n",
+ "max_length"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "6edafc86",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image_ids = list(ma.keys())\n",
+ "split = int(len(image_ids) * 0.90)\n",
+ "train = image_ids[:split]\n",
+ "test = image_ids[split:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "b214a4dd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def data_generator(data_keys, mapping, features, tokenizer, max_length, vocab_size, batch_size):\n",
+ "    X1, X2, y = list(), list(), list()\n",
+ "    n = 0\n",
+ "    while True:\n",
+ "        for key in data_keys:\n",
+ "            n += 1\n",
+ "            captions = mapping[key]\n",
+ "            for caption in captions:\n",
+ "                seq = tokenizer.texts_to_sequences([caption])[0]\n",
+ "                for i in range(1, len(seq)):\n",
+ "                    in_seq, out_seq = seq[:i], seq[i]\n",
+ "                    in_seq = pad_sequences([in_seq], maxlen=max_length)[0]\n",
+ "                    out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]\n",
+ "                    X1.append(features[key][0])\n",
+ "                    X2.append(in_seq)\n",
+ "                    y.append(out_seq)\n",
+ "            if n == batch_size:\n",
+ "                X1, X2, y = np.array(X1), np.array(X2), np.array(y)\n",
+ "                yield [X1, X2], y\n",
+ "                X1, X2, y = list(), list(), list()\n",
+ "                n = 0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "ba019340",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.\n"
+ ]
+ }
+ ],
+ "source": [
+ "inputs1 = Input(shape=(4096,))\n",
+ "fe1 = Dropout(0.4)(inputs1)\n",
+ "fe2 = Dense(256, activation='relu')(fe1)\n",
+ "inputs2 = Input(shape=(max_length,))\n",
+ "se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)\n",
+ "se2 = Dropout(0.4)(se1)\n",
+ "se3 = LSTM(256)(se2)\n",
+ "decoder1 = add([fe2, se3])\n",
+ "decoder2 = Dense(256, activation='relu')(decoder1)\n",
+ "outputs = Dense(vocab_size, activation='softmax')(decoder2)\n",
+ "\n",
+ "model = Model(inputs=[inputs1, inputs2], outputs=outputs)\n",
+ "model.compile(loss='categorical_crossentropy', optimizer='adam')\n",
+ "\n",
+ "plot_model(model, show_shapes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "c9cd441e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "227/227 [==============================] - 634s 3s/step - loss: 5.2148\n",
+ "227/227 [==============================] - 552s 2s/step - loss: 3.9993\n",
+ "227/227 [==============================] - 547s 2s/step - loss: 3.5808\n",
+ "227/227 [==============================] - 565s 2s/step - loss: 3.3151\n",
+ "227/227 [==============================] - 583s 3s/step - loss: 3.1139\n",
+ "227/227 [==============================] - 563s 2s/step - loss: 2.9658\n",
+ "227/227 [==============================] - 563s 2s/step - loss: 2.8508\n",
+ "227/227 [==============================] - 562s 2s/step - loss: 2.7600\n",
+ "227/227 [==============================] - 570s 3s/step - loss: 2.6801\n",
+ "227/227 [==============================] - 564s 2s/step - loss: 2.6098\n",
+ "227/227 [==============================] - 564s 2s/step - loss: 2.5561\n",
+ "227/227 [==============================] - 568s 3s/step - loss: 2.4974\n",
+ "227/227 [==============================] - 575s 3s/step - loss: 2.4453\n",
+ "227/227 [==============================] - 572s 3s/step - loss: 2.3967\n",
+ "227/227 [==============================] - 576s 3s/step - loss: 2.3553\n",
+ "227/227 [==============================] - 570s 3s/step - loss: 2.3203\n",
+ "227/227 [==============================] - 570s 3s/step - loss: 2.2833\n",
+ "227/227 [==============================] - 560s 2s/step - loss: 2.2474\n",
+ "227/227 [==============================] - 559s 2s/step - loss: 2.2182\n",
+ "227/227 [==============================] - 561s 2s/step - loss: 2.1891\n"
+ ]
+ }
+ ],
+ "source": [
+ "epochs = 20\n",
+ "batch_size = 32\n",
+ "steps = len(train) // batch_size\n",
+ "\n",
+ "for i in range(epochs):\n",
+ "    generator = data_generator(train, ma, features, tokenizer, max_length, vocab_size, batch_size)\n",
+ "    model.fit(generator, epochs=1, steps_per_epoch=steps, verbose=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "3e22a08f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.save(os.path.join(work,'image_caption.h5'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "8d6cae78",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def idx_word(integer,tok):\n",
+ "    for word,index in tok.word_index.items():\n",
+ "        if index == integer:\n",
+ "            return word\n",
+ "    return None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "68502106",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_caption(model,image,tok,max_len):\n",
+ "    in_text=\"startseq\"\n",
+ "    for i in range(max_len):\n",
+ "        seq=tok.texts_to_sequences([in_text])[0]\n",
+ "        seq=pad_sequences([seq],maxlen=max_len)\n",
+ "        yhat = model.predict([image, seq], verbose=0)\n",
+ "        yhat = np.argmax(yhat)\n",
+ "        word = idx_word(yhat, tok)\n",
+ "        if word is None:\n",
+ "            break\n",
+ "        in_text += \" \" + word\n",
+ "        if word == 'endseq':\n",
+ "            break\n",
+ "    return in_text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d6fa2905",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cebaf5ee07d54f4bb56ce83763063629",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "  0%|          | 0/810 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from nltk.translate.bleu_score import corpus_bleu\n",
+ "actual, predicted = list(), list()\n",
+ "for key in tqdm(test):\n",
+ "    captions = ma[key]\n",
+ "    y_pred = predict_caption(model, features[key], tokenizer, max_length)\n",
+ "    actual_captions = [caption.split() for caption in captions]\n",
+ "    y_pred = y_pred.split()\n",
+ "    # append to the list\n",
+ "    actual.append(actual_captions)\n",
+ "    predicted.append(y_pred)\n",
+ "print(\"BLEU-1: %f\" % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0)))\n",
+ "print(\"BLEU-2: %f\" % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "468e17a6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "import matplotlib.pyplot as plt\n",
+ "def generate_caption(image_name):\n",
+ "    image_id = image_name.split('.')[0]\n",
+ "    img_path = os.path.join(base, \"Images\", image_name)\n",
+ "    image = Image.open(img_path)\n",
+ "    captions = ma[image_id]\n",
+ "    print('---------------------Actual---------------------')\n",
+ "    for caption in captions:\n",
+ "        print(caption)\n",
+ "    # predict the caption\n",
+ "    y_pred = predict_caption(model, features[image_id], tokenizer, max_length)\n",
+ "    print('--------------------Predicted--------------------')\n",
+ "    print(y_pred)\n",
+ "    plt.imshow(image)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "cc4d2af9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(os.path.join(base,\"captions.txt\"),\"r\") as f:\n",
+ "    next(f) # skip the header line\n",
+ "    caption_data=f.read()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "ddb5ee13",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c26d20eded654d9a82beaad96d6fcb6b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "  0%|          | 0/40456 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "ma={}\n",
+ "data=caption_data.split(\"\\n\")\n",
+ "for line in tqdm(data):\n",
+ "    mapp=line.split(\",\")\n",
+ "    if len(mapp)<2:\n",
+ "        continue\n",
+ "    im_id=mapp[0]\n",
+ "    cap=\",\".join(mapp[1:]) # rejoin, in case the caption itself contains commas\n",
+ "    im_id=im_id.split(\".\")[0]\n",
+ "    if im_id not in ma:\n",
+ "        ma[im_id]=[]\n",
+ "    ma[im_id].append(cap)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "05cab232",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def process_text(cap):\n",
+ "    cap=cap.lower()\n",
+ "    cap=re.sub(r\"[^a-z ]\",\"\",cap) # regex cleanup; str.replace would match the pattern literally\n",
+ "    cap=re.sub(r\"\\s+\",\" \",cap)\n",
+ "    cap=\"[start] \"+\" \".join([word for word in cap.split(\" \") if len(word)>1])+\" [end]\"\n",
+ "    return cap"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "f75f26df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def clean(ma):\n",
+ "    for key, cap in ma.items():\n",
+ "        for i in range(len(cap)):\n",
+ "            cap[i]=process_text(cap[i])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "15693ddd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['A child in a pink dress is climbing up a set of stairs in an entry way .',\n",
+ " 'A girl going into a wooden building .',\n",
+ " 'A little girl climbing into a wooden playhouse .',\n",
+ " 'A little girl climbing the stairs to her playhouse .',\n",
+ " 'A little girl in a pink dress going into a wooden cabin .']"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ma[\"1000268201_693b08cb0e\"] # a quick check before cleaning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "defc5403",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['[start] child in pink dress is climbing up set of stairs in an entry way [end]',\n",
+ " '[start] girl going into wooden building [end]',\n",
+ " '[start] little girl climbing into wooden playhouse [end]',\n",
+ " '[start] little girl climbing the stairs to her playhouse [end]',\n",
+ " '[start] little girl in pink dress going into wooden cabin [end]']"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "clean(ma)\n",
+ "ma[\"1000268201_693b08cb0e\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "f5913f53",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_cap=[]\n",
+ "for key in ma.keys():\n",
+ "    for cap in ma[key]:\n",
+ "        all_cap.append(cap)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "84d681f2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "40455"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(all_cap)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "4dbe92b1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "8311"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tok=Tokenizer()\n",
+ "tok.fit_on_texts(all_cap)\n",
+ "vocab_size=len(tok.word_index)+1\n",
+ "vocab_size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "776312f5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "31"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "max_len=max(len(cap.split()) for cap in all_cap)\n",
+ "max_len"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "57a14f3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image_ids=list(ma.keys())\n",
+ "split=int(len(image_ids)*0.90)\n",
+ "train=image_ids[:split]\n",
+ "test=image_ids[split:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "69b7ff8a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "7281"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "378f6cb7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def data_gen(data_keys,ma,fs,tok,max_len,vocab_size,batch_size):\n",
+ "    x1,x2,y=list(),list(),list()\n",
+ "    n=0\n",
+ "    while True:\n",
+ "        for key in data_keys:\n",
+ "            n+=1\n",
+ "            cap=ma[key]\n",
+ "            for cap_i in cap:\n",
+ "                seq=tok.texts_to_sequences([cap_i])[0]\n",
+ "                for i in range(1,len(seq)): # start at 1 so the input sequence is never empty\n",
+ "                    in_seq,out_seq=seq[:i],seq[i]\n",
+ "                    in_seq=pad_sequences([in_seq],maxlen=max_len)[0]\n",
+ "                    out_seq=to_categorical([out_seq],num_classes=vocab_size)[0]\n",
+ "                    x1.append(fs[key][0])\n",
+ "                    x2.append(in_seq)\n",
+ "                    y.append(out_seq)\n",
+ "            if n==batch_size:\n",
+ "                x1=np.array(x1)\n",
+ "                x2=np.array(x2)\n",
+ "                y=np.array(y)\n",
+ "                yield [x1,x2],y\n",
+ "                x1,x2,y=list(),list(),list()\n",
+ "                n=0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "f5f13047",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.\n"
+ ]
+ }
+ ],
+ "source": [
+ "inputs1=Input(shape=(4096,))\n",
+ "fe1=Dropout(0.4)(inputs1)\n",
+ "fe2=Dense(256,activation='relu')(fe1)\n",
+ "inputs2=Input(shape=(max_len,))\n",
+ "se1=Embedding(vocab_size,256,mask_zero=True)(inputs2)\n",
+ "se2=Dropout(0.4)(se1)\n",
+ "se3=LSTM(256)(se2)\n",
+ "\n",
+ "decoder1=add([fe2,se3])\n",
+ "decoder2=Dense(256,activation='relu')(decoder1)\n",
+ "outputs=Dense(vocab_size,activation='softmax')(decoder2)\n",
+ "\n",
+ "model=Model(inputs=[inputs1,inputs2],outputs=outputs)\n",
+ "model.compile(loss=\"categorical_crossentropy\",optimizer='adam')\n",
+ "\n",
+ "plot_model(model,show_shapes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d63d6d4b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\r",
+ "   1/7281 [..............................] - ETA: 38:03:57 - loss: 9.0597"
+ ]
+ }
+ ],
+ "source": [
+ "epochs=15\n",
+ "batch_size=64\n",
+ "steps=len(train)//batch_size\n",
+ "for i in range(epochs):\n",
+ "    generator=data_gen(train,ma,fs,tok,max_len,vocab_size,batch_size)\n",
+ "    model.fit(generator,epochs=1,steps_per_epoch=steps,verbose=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "3322120d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.save(os.path.join(work,'image_caption.h5'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "303f7a8e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def idx_word(integer,tok):\n",
+ "    for word,index in tok.word_index.items():\n",
+ "        if index == integer:\n",
+ "            return word\n",
+ "    return None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "541d09e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_caption(model,image,tok,max_len):\n",
+ "    # Tokenizer's default filters strip '[' and ']', so '[start]'/'[end]' tokenize as 'start'/'end'\n",
+ "    in_text=\"[start]\"\n",
+ "    for i in range(max_len):\n",
+ "        seq=tok.texts_to_sequences([in_text])[0]\n",
+ "        seq=pad_sequences([seq],maxlen=max_len) # keep the batch dimension for model.predict\n",
+ "        yhat=np.argmax(model.predict([image,seq],verbose=0))\n",
+ "        word=idx_word(yhat,tok)\n",
+ "        if word is None:\n",
+ "            break\n",
+ "        in_text+=\" \"+word\n",
+ "        if word=='end':\n",
+ "            break\n",
+ "    return in_text"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }