Spaces:
Running
Running
Create function to call models
Browse files
model.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
MAX_LENGTH = 40
|
| 7 |
+
BATCH_SIZE = 32
|
| 8 |
+
BUFFER_SIZE = 1000
|
| 9 |
+
EMBEDDING_DIM = 512
|
| 10 |
+
UNITS = 512
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# LOADING DATA
|
| 14 |
+
vocab = pickle.load(open('vocabulary/vocab_coco.file', 'rb'))
|
| 15 |
+
|
| 16 |
+
tokenizer = tf.keras.layers.TextVectorization(
|
| 17 |
+
standardize = None,
|
| 18 |
+
output_sequence_length = MAX_LENGTH,
|
| 19 |
+
vocabulary = vocab
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
idx2word = tf.keras.layers.StringLookup(
|
| 23 |
+
mask_token = "",
|
| 24 |
+
vocabulary = tokenizer.get_vocabulary(),
|
| 25 |
+
invert = True
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def load_image_from_path(img_path):
|
| 29 |
+
img = tf.io.read_file(img_path)
|
| 30 |
+
img = tf.io.decode_jpeg(img, channels=3)
|
| 31 |
+
img = tf.keras.layers.Resizing(299, 299)(img)
|
| 32 |
+
img = tf.keras.applications.inception_v3.preprocess_input(img)
|
| 33 |
+
return img
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def generate_caption(img, caption_model, add_noise=False):
|
| 37 |
+
if isinstance(img, str):
|
| 38 |
+
img = load_image_from_path(img)
|
| 39 |
+
|
| 40 |
+
if add_noise == True:
|
| 41 |
+
noise = tf.random.normal(img.shape)*0.1
|
| 42 |
+
img = (img + noise)
|
| 43 |
+
img = (img - tf.reduce_min(img))/(tf.reduce_max(img) - tf.reduce_min(img))
|
| 44 |
+
|
| 45 |
+
img = tf.expand_dims(img, axis=0)
|
| 46 |
+
img_embed = caption_model.cnn_model(img)
|
| 47 |
+
img_encoded = caption_model.encoder(img_embed, training=False)
|
| 48 |
+
|
| 49 |
+
y_inp = '[start]'
|
| 50 |
+
for i in range(MAX_LENGTH-1):
|
| 51 |
+
tokenized = tokenizer([y_inp])[:, :-1]
|
| 52 |
+
mask = tf.cast(tokenized != 0, tf.int32)
|
| 53 |
+
pred = caption_model.decoder(
|
| 54 |
+
tokenized, img_encoded, training=False, mask=mask)
|
| 55 |
+
|
| 56 |
+
pred_idx = np.argmax(pred[0, i, :])
|
| 57 |
+
pred_word = idx2word(pred_idx).numpy().decode('utf-8')
|
| 58 |
+
if pred_word == '[end]':
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
y_inp += ' ' + pred_word
|
| 62 |
+
|
| 63 |
+
y_inp = y_inp.replace('[start] ', '')
|
| 64 |
+
return y_inp
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_caption_model():
|
| 68 |
+
encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
|
| 69 |
+
decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
|
| 70 |
+
|
| 71 |
+
cnn_model = CNN_Encoder()
|
| 72 |
+
|
| 73 |
+
caption_model = ImageCaptioningModel(
|
| 74 |
+
cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def call_fn(batch, training):
|
| 78 |
+
return batch
|
| 79 |
+
|
| 80 |
+
caption_model.call = call_fn
|
| 81 |
+
sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))
|
| 82 |
+
|
| 83 |
+
caption_model((sample_x, sample_y))
|
| 84 |
+
|
| 85 |
+
sample_img_embed = caption_model.cnn_model(sample_x)
|
| 86 |
+
sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
|
| 87 |
+
caption_model.decoder(sample_y, sample_enc_out, training=False)
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
caption_model.load_weights('models/trained_coco_weights.h5')
|
| 91 |
+
except FileNotFoundError:
|
| 92 |
+
caption_model.load_weights('image-caption-generator/models/trained_coco_weights.h5')
|
| 93 |
+
|
| 94 |
+
return caption_model
|