peshk1n's picture
Update app.py
e90615d verified
from tensorflow import keras
import numpy as np
import tensorflow as tf
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io
from PIL import Image
import json
from tensorflow.keras import layers, Model
import string
from transformers import TFAutoModel
import gradio as gr
import os
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
# NOTE(review): keras and tensorflow are already imported above in this file,
# so setting the backend here presumably comes too late to affect them — confirm.
os.environ["KERAS_BACKEND"] = "tensorflow"

# Special tokens; they are looked up in the vocabulary tables loaded below.
start_token = "[BOS]"
end_token = "[EOS]"
cls_token = "[CLS]"

# Dataset location and split names. Not referenced by the rest of the visible code.
data_dir = '/content/coco'
data_type_train = 'train2014'
data_type_val = 'val2014'

# Text settings: vocabulary size and caption length (in tokens).
vocab_size = 24000
sentence_length = 20
# Only used in the visible code to size the dummy inputs that build the RNN model.
batch_size = 128

# Image / transformer settings shared by the CoCa encoder and decoder.
img_size = 224
proj_dim = 192
dropout_rate = 0.1
num_patches = 14
patch_size = img_size // num_patches  # 224 // 14 == 16
num_heads = 3
num_layers = 6

# Attentional pooling reuses the projection width and head count defined above.
attn_pool_dim = proj_dim
attn_pool_heads = num_heads
# Number of learned query vectors used by the captioning pooler.
cap_query_num = 128

# Sizes passed to the LSTM-based captioning model (embedding dim, LSTM units).
rnn_embedding_dim = 256
rnn_proj_dim = 512
def _load_vocab(path, key_type, value_type):
    """Read a JSON object from *path* and convert every key and value.

    Args:
        path: Location of a UTF-8 encoded JSON file holding one flat object.
        key_type: Callable applied to each key.
        value_type: Callable applied to each value.

    Returns:
        A dict with converted keys and values.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        raw = json.load(handle)
    return {key_type(key): value_type(value) for key, value in raw.items()}


# token -> id, with numpy string keys and numpy int64 values.
word_index = _load_vocab('vocabs/word_index.json', np.str_, np.int64)
# id -> token, with numpy int64 keys and numpy string values.
index_word = _load_vocab('vocabs/index_word.json', np.int64, np.str_)
# Id of the [CLS] token that the CoCa decoder appends to every caption.
cls_token_id = word_index[cls_token]
class PositionalEmbedding(layers.Layer):
    """Sum of a token embedding and a learned position embedding.

    Args:
        sequence_length: Number of rows in the position embedding table, i.e.
            the longest sequence that can be embedded.
        input_dim: Number of rows in the token embedding table.
        output_dim: Width of both embeddings.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )

    def call(self, inputs):
        """Embed ``inputs`` (token ids) and add position embeddings.

        Position ids are derived from the last axis of ``inputs`` rather than
        from ``sequence_length``, so inputs shorter than ``sequence_length``
        broadcast correctly. Inputs exactly ``sequence_length`` long behave as
        before. Inputs longer than ``sequence_length`` index past the position
        table and are not supported.
        """
        current_length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=current_length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
        })
        return config
class AttentionalPooling(layers.Layer):
    """Cross-attention pooling followed by layer normalization.

    The caller supplies the query tensor; the features act as both keys and
    values.

    Args:
        embed_dim: Passed to ``MultiHeadAttention`` as ``key_dim``.
        num_heads: Number of attention heads (defaults to 6).
    """

    def __init__(self, embed_dim, num_heads=6):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.multihead_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim,
        )
        self.norm = layers.LayerNormalization()

    def call(self, features, query):
        """Attend from ``query`` over ``features`` and normalize the result."""
        pooled = self.multihead_attention(
            query=query,
            value=features,
            key=features,
        )
        pooled = self.norm(pooled)
        return pooled
class TransformerBlock(layers.Layer):
    """Causal self-attention block with optional cross-attention.

    Args:
        embed_dim: Output width of the feed-forward stack; also used as
            ``key_dim`` of the attention layers.
        dense_dim: Width of the hidden feed-forward layer.
        num_heads: Number of attention heads.
        dropout_rate: Rate used by every dropout and by the attention layers.
        ln_epsilon: Epsilon of the layer normalizations.
        is_multimodal: When true, the cross-attention sublayers are created.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1,
                 ln_epsilon=1e-6, is_multimodal=False, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon

        # NOTE(review): sublayers are created in the same order as before;
        # loading saved weights may depend on that order — confirm before
        # reordering.
        self.self_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim,
            dropout=self.dropout_rate,
        )

        # Cross-attention pieces exist only for the multimodal variant.
        if is_multimodal:
            self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
            self.dropout2 = layers.Dropout(self.dropout_rate)
            self.cross_attention = layers.MultiHeadAttention(
                num_heads=self.num_heads,
                key_dim=self.embed_dim,
                dropout=self.dropout_rate,
            )

        self.dense_proj = tf.keras.Sequential([
            layers.Dense(self.dense_dim, activation="gelu"),
            layers.Dropout(self.dropout_rate),
            layers.Dense(self.embed_dim),
        ])

        self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        self.dropout1 = layers.Dropout(self.dropout_rate)
        self.dropout3 = layers.Dropout(self.dropout_rate)

    def get_causal_attention_mask(self, inputs):
        """Return a boolean lower-triangular mask with a leading axis of 1.

        The mask is square, with side length taken from axis 1 of ``inputs``.
        """
        length = tf.shape(inputs)[1]
        all_true = tf.ones((length, length), tf.bool)
        lower_triangle = tf.linalg.band_part(all_true, -1, 0)
        return tf.expand_dims(lower_triangle, 0)

    def get_combined_mask(self, causal_mask, padding_mask):
        """AND the causal mask with a padding mask broadcast over axis 1."""
        keep = tf.expand_dims(tf.cast(padding_mask, tf.bool), 1)
        return tf.logical_and(causal_mask, keep)

    def call(self, inputs, encoder_outputs=None, mask=None):
        """Apply self-attention, optional cross-attention and the feed-forward.

        Args:
            inputs: Sequence tensor to transform.
            encoder_outputs: Optional tensor used as key and value of the
                cross-attention. Passing it requires ``is_multimodal=True`` at
                construction, because the cross-attention sublayers only exist
                in that case.
            mask: Optional padding mask combined with the causal mask.
        """
        attention_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            attention_mask = self.get_combined_mask(attention_mask, mask)

        # The residual below is added to the normalized tensor, as in the
        # original implementation.
        x = self.norm1(inputs)
        self_attended = self.self_attention(
            query=x, key=x, value=x, attention_mask=attention_mask
        )
        x = x + self.dropout1(self_attended)

        if encoder_outputs is not None:
            normed = self.norm2(x)
            cross_attended = self.cross_attention(
                query=normed, key=encoder_outputs, value=encoder_outputs
            )
            x = x + self.dropout2(cross_attended)

        feed_forward = self.dense_proj(self.norm3(x))
        return x + self.dropout3(feed_forward)
class UnimodalTextDecoder(layers.Layer):
    """Stack of text-only ``TransformerBlock`` layers with a final LayerNorm.

    Args:
        embed_dim: Embedding width handed to every block.
        dense_dim: Feed-forward hidden width handed to every block.
        num_heads: Number of attention heads per block.
        dropout_rate: Dropout rate handed to every block.
        ln_epsilon: LayerNorm epsilon handed to every block.
        num_layers: Number of blocks in the stack.
        **kwargs: Forwarded to the base ``Layer`` constructor. They were
            previously accepted but silently dropped.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1,
                 ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers
        # Attribute names are kept as-is; saved weights may be keyed on them.
        self.layers = [
            TransformerBlock(
                self.embed_dim,
                self.dense_dim,
                self.num_heads,
                self.dropout_rate,
                self.ln_epsilon,
                is_multimodal=False,
            )
            for _ in range(self.num_layers)
        ]
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, mask=None):
        """Run ``x`` through every block in order, then normalize.

        Args:
            x: Sequence tensor to transform.
            mask: Optional padding mask handed to every block.
        """
        for layer in self.layers:
            x = layer(inputs=x, mask=mask)
        return self.norm(x)
class MultimodalTextDecoder(layers.Layer):
    """Stack of cross-attending ``TransformerBlock`` layers with a final LayerNorm.

    Args:
        embed_dim: Embedding width handed to every block.
        dense_dim: Feed-forward hidden width handed to every block.
        num_heads: Number of attention heads per block.
        dropout_rate: Dropout rate handed to every block.
        ln_epsilon: LayerNorm epsilon handed to every block.
        num_layers: Number of blocks in the stack.
        **kwargs: Forwarded to the base ``Layer`` constructor. They were
            previously accepted but silently dropped.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1,
                 ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers
        # Attribute names are kept as-is; saved weights may be keyed on them.
        self.layers = [
            TransformerBlock(
                self.embed_dim,
                self.dense_dim,
                self.num_heads,
                self.dropout_rate,
                self.ln_epsilon,
                is_multimodal=True,
            )
            for _ in range(self.num_layers)
        ]
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, encoder_outputs, mask=None):
        """Run ``x`` through every block, attending over ``encoder_outputs``.

        Args:
            x: Sequence tensor to transform.
            encoder_outputs: Tensor used as key and value of each block's
                cross-attention.
            mask: Optional padding mask handed to every block.
        """
        for layer in self.layers:
            x = layer(inputs=x, encoder_outputs=encoder_outputs, mask=mask)
        return self.norm(x)
class EmbedToLatents(layers.Layer):
    """Bias-free linear projection followed by L2 normalization.

    Args:
        dim_latents: Output width of the projection.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, dim_latents, **kwargs):
        super().__init__(**kwargs)
        self.dim_latents = dim_latents
        self.to_latents = layers.Dense(self.dim_latents, use_bias=False)

    def call(self, inputs):
        """Project ``inputs`` and scale the last axis to unit L2 norm."""
        projected = self.to_latents(inputs)
        unit_norm = tf.math.l2_normalize(projected, axis=-1)
        return unit_norm
class Perplexity(tf.keras.metrics.Metric):
    """Streaming perplexity over non-padding tokens.

    Accumulates the summed sparse categorical cross-entropy and the number of
    tokens whose label is not 0, and reports ``exp(total_loss / total_tokens)``.

    Args:
        name: Metric name.
        from_logits: Whether ``y_pred`` holds logits (the default, matching
            the previous hard-coded behavior) or probabilities.
        **kwargs: Forwarded to the base ``Metric`` constructor.
    """

    def __init__(self, name='perplexity', from_logits=True, **kwargs):
        super().__init__(name=name, **kwargs)
        self.from_logits = from_logits
        self.total_loss = self.add_weight(name='total_loss', initializer='zeros')
        self.total_tokens = self.add_weight(name='total_tokens', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch to the running totals.

        Args:
            y_true: Integer labels; entries equal to 0 are treated as padding
                and excluded.
            y_pred: Predictions, interpreted according to ``from_logits``.
            sample_weight: Accepted for API compatibility; not used.
        """
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=self.from_logits, reduction='none'
        )
        loss = loss_fn(y_true, y_pred)
        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
        self.total_loss.assign_add(tf.reduce_sum(loss * mask))
        self.total_tokens.assign_add(tf.reduce_sum(mask))

    def result(self):
        """Return the perplexity accumulated so far.

        Before any non-padding token has been seen the mean loss is taken to
        be 0 (so the result is 1.0) rather than the NaN produced by 0 / 0.
        """
        mean_loss = tf.math.divide_no_nan(self.total_loss, self.total_tokens)
        return tf.exp(mean_loss)

    def reset_state(self):
        """Zero both accumulators."""
        self.total_loss.assign(0.0)
        self.total_tokens.assign(0.0)

    def reset_states(self):
        """Alias of ``reset_state`` kept for callers using the older name."""
        self.reset_state()
# Checkpoint identifier of the image backbone handed to TFAutoModel.
# NOTE: `model_name` is reassigned further down when model weights are loaded.
model_name = "WinKawaks/vit-tiny-patch16-224"
vit_tiny_model = TFAutoModel.from_pretrained(model_name)
# Mark the backbone itself and each of its layers as trainable.
vit_tiny_model.trainable = True
for layer in vit_tiny_model.layers:
    layer.trainable = True
class CoCaEncoder(tf.keras.Model):
    """Image encoder: backbone features pooled by two sets of learned queries.

    Args:
        vit: Backbone model; its call result must expose ``last_hidden_state``.
        **kwargs: Forwarded to the base ``Model`` constructor.
    """

    def __init__(self, vit, **kwargs):
        super().__init__(**kwargs)
        self.vit = vit
        self.contrastive_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
        self.caption_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)

        # One learned query for the contrastive branch.
        self.con_query = tf.Variable(
            initial_value=tf.random.normal([1, 1, proj_dim]),
            trainable=True,
            name="con_query",
        )
        # ``cap_query_num`` learned queries for the captioning branch.
        self.cap_query = tf.Variable(
            initial_value=tf.random.normal([1, cap_query_num, proj_dim]),
            trainable=True,
            name="cap_query",
        )

    def call(self, input, training=False):
        """Encode images into contrastive and captioning features.

        Args:
            input: Image batch handed directly to the backbone.
            training: Accepted for the Keras call signature; not referenced in
                the body.

        Returns:
            Tuple ``(con_feature, cap_feature)`` from the two poolers.
        """
        img_feature = self.vit(input).last_hidden_state

        # Repeat the learned queries so there is one copy per batch element.
        n_images = tf.shape(img_feature)[0]
        con_queries = tf.repeat(self.con_query, repeats=n_images, axis=0)
        cap_queries = tf.repeat(self.cap_query, repeats=n_images, axis=0)

        con_feature = self.contrastive_pooling(img_feature, con_queries)
        cap_feature = self.caption_pooling(img_feature, cap_queries)
        return con_feature, cap_feature
class CoCaDecoder(tf.keras.Model):
    """Text decoder: unimodal stack followed by an image-attending stack.

    Args:
        cls_token_id: Token id appended to the end of every caption.
        num_heads: Attention heads per transformer block.
        num_layers: Number of blocks in each of the two stacks.
        **kwargs: Forwarded to the base ``Model`` constructor.
    """

    def __init__(self, cls_token_id, num_heads, num_layers, **kwargs):
        super().__init__(**kwargs)
        self.cls_token_id = cls_token_id
        self.pos_emb = PositionalEmbedding(sentence_length, vocab_size, proj_dim)
        self.unimodal_decoder = UnimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.multimodal_decoder = MultimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.to_logits = tf.keras.layers.Dense(vocab_size, name='logits_projection')
        self.norm = layers.LayerNormalization()

    def call(self, inputs, training=False):
        """Decode captions against pooled image features.

        Args:
            inputs: Pair ``(input_text, cap_feature)`` of token ids and the
                encoder's captioning features.
            training: Accepted for the Keras call signature; not referenced in
                the body.

        Returns:
            Tuple ``(cls_token_feature, multimodal_logits)``: the normalized
            unimodal output at the appended [CLS] position, and vocabulary
            logits for the remaining positions.
        """
        input_text, cap_feature = inputs
        n_captions = tf.shape(input_text)[0]

        # Append one [CLS] id to the end of every caption.
        cls_column = tf.fill(
            [n_captions, 1], tf.cast(self.cls_token_id, input_text.dtype)
        )
        ids = tf.concat([input_text, cls_column], axis=1)

        # Id 0 is treated as padding; the appended [CLS] column gets a zero
        # (False) mask entry.
        text_mask = tf.not_equal(input_text, 0)
        cls_mask = tf.zeros([n_captions, 1], dtype=text_mask.dtype)
        extended_mask = tf.concat([text_mask, cls_mask], axis=1)

        unimodal_out = self.unimodal_decoder(self.pos_emb(ids), mask=extended_mask)

        # The multimodal stack sees everything except the [CLS] position.
        text_states = unimodal_out[:, :-1, :]
        multimodal_out = self.multimodal_decoder(text_states, cap_feature, mask=text_mask)

        cls_token_feature = self.norm(unimodal_out[:, -1:, :])
        multimodal_logits = self.to_logits(multimodal_out)
        return cls_token_feature, multimodal_logits
class CoCaModel(tf.keras.Model):
    """CoCa model trained with a captioning loss plus a contrastive loss.

    Args:
        vit: Image backbone handed to ``CoCaEncoder``.
        cls_token_id: Token id handed to ``CoCaDecoder``.
        num_heads: Attention heads handed to ``CoCaDecoder``.
        num_layers: Layer count handed to ``CoCaDecoder``.
    """

    def __init__(self, vit, cls_token_id, num_heads, num_layers):
        super().__init__()
        self.encoder = CoCaEncoder(vit, name="coca_encoder")
        self.decoder = CoCaDecoder(cls_token_id, num_heads, num_layers, name="coca_decoder")
        self.img_to_latents = EmbedToLatents(proj_dim)
        self.text_to_latents = EmbedToLatents(proj_dim)

        self.pad_id = 0
        self.temperature = 0.07
        self.caption_loss_weight = 1.0
        self.contrastive_loss_weight = 1.0

        self.perplexity = Perplexity()

    def call(self, inputs, training=False):
        """Run encoder and decoder on an ``(image, text)`` pair.

        Returns:
            Tuple ``(con_feature, cls_token_feature, multimodal_logits)``.
        """
        image, text = inputs
        con_feature, cap_feature = self.encoder(image)
        cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
        return con_feature, cls_token_feature, multimodal_logits

    def compile(self, optimizer):
        """Compile the base model with defaults and attach ``optimizer``."""
        super().compile()
        self.optimizer = optimizer

    def compute_caption_loss(self, multimodal_out, caption_target):
        """Mean sparse cross-entropy on logits, ignoring class ``pad_id``."""
        per_token = tf.keras.losses.sparse_categorical_crossentropy(
            caption_target,
            multimodal_out,
            from_logits=True,
            ignore_class=self.pad_id,
        )
        return tf.reduce_mean(per_token)

    def compute_contrastive_loss(self, con_feature, cls_feature):
        """Symmetric cross-entropy over the text/image similarity matrix.

        Both features have their axis 1 squeezed away, are projected to
        unit-norm latents, and are compared by dot product divided by
        ``temperature``. Row ``i`` of the similarity matrix is labeled ``i``.
        """
        text_latents = self.text_to_latents(tf.squeeze(cls_feature, axis=1))
        image_latents = self.img_to_latents(tf.squeeze(con_feature, axis=1))

        sim = tf.matmul(text_latents, image_latents, transpose_b=True) / self.temperature
        contrastive_labels = tf.range(tf.shape(sim)[0])

        rows_loss = tf.keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, sim, from_logits=True
        )
        cols_loss = tf.keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(sim), from_logits=True
        )
        return tf.reduce_mean((rows_loss + cols_loss) * 0.5)

    def _caption_and_contrastive_losses(self, con_feature, cls_feature,
                                        multimodal_out, caption_target):
        """Return ``(total_loss, caption_loss, contrastive_loss)`` for a batch."""
        caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
        contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
        total_loss = (
            self.caption_loss_weight * caption_loss
            + self.contrastive_loss_weight * contrastive_loss
        )
        return total_loss, caption_loss, contrastive_loss

    def _log_step(self, total_loss, caption_loss, contrastive_loss,
                  caption_target, multimodal_out):
        """Update perplexity with the batch and build the per-step log dict."""
        self.perplexity.update_state(caption_target, multimodal_out)
        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result(),
        }

    def train_step(self, data):
        """One optimization step on ``((images, caption_input), caption_target)``."""
        (images, caption_input), caption_target = data

        with tf.GradientTape() as tape:
            con_feature, cls_feature, multimodal_out = self(
                [images, caption_input], training=True
            )
            total_loss, caption_loss, contrastive_loss = (
                self._caption_and_contrastive_losses(
                    con_feature, cls_feature, multimodal_out, caption_target
                )
            )

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._log_step(
            total_loss, caption_loss, contrastive_loss, caption_target, multimodal_out
        )

    def test_step(self, data):
        """One evaluation step on ``((images, caption_input), caption_target)``."""
        (images, caption_input), caption_target = data

        con_feature, cls_feature, multimodal_out = self(
            [images, caption_input], training=False
        )
        total_loss, caption_loss, contrastive_loss = (
            self._caption_and_contrastive_losses(
                con_feature, cls_feature, multimodal_out, caption_target
            )
        )

        return self._log_step(
            total_loss, caption_loss, contrastive_loss, caption_target, multimodal_out
        )

    def reset_metrics(self):
        """Reset the perplexity metric."""
        self.perplexity.reset_state()
# Build the CoCa model around the ViT backbone loaded above.
coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id, num_heads=num_heads, num_layers=num_layers)
# Run one forward pass on zero-filled inputs before loading weights
# (presumably so that all variables exist — confirm).
# The dummy image is laid out as (1, 3, img_size, img_size) and the dummy caption
# holds sentence_length - 1 token ids.
dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)
dummy_captions = tf.zeros((1, sentence_length-1), dtype=tf.int64)
_ = coca_model((dummy_features, dummy_captions))
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
coca_model.compile(optimizer)
# NOTE: `save_dir` already ends with "/", so the f-string below yields
# "models//coca.weights.h5".
save_dir = "models/"
model_name = "coca"
coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
# Shape constants used when reshaping ResNet50 outputs in create_features:
# `reg_count` feature vectors of width `img_embed_dim` per image
# (presumably the 7x7 spatial grid of the final ResNet50 feature map — confirm).
img_embed_dim = 2048
reg_count = 7 * 7
# ImageNet-pretrained ResNet50 without its classification head.
base_model = ResNet50(weights='imagenet', include_top=False)
# Wrapper model exposing the backbone's output directly; used by create_features.
model = Model(inputs=base_model.input, outputs=base_model.output)
def preprocess_image(img):
    """Prepare a single image for the ResNet50 feature extractor.

    Args:
        img: A PIL image, or an array-like / tensor accepted by
            ``tf.image.resize``.

    Returns:
        A numpy array with a leading batch axis of size 1, holding the image
        resized to ``img_size`` x ``img_size`` and passed through the ResNet50
        ``preprocess_input``.
    """
    # PIL images expose `.mode`; anything other than RGB (e.g. RGBA, L, P)
    # would produce the wrong channel count or rank, so force three channels.
    # Arrays and tensors have no `.mode` and pass through unchanged.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    # tf.image.resize already returns a tensor, so no extra conversion is needed.
    img = tf.image.resize(img, (img_size, img_size))
    img = preprocess_input(img)
    return np.expand_dims(img, axis=0)
def create_features(img):
    """Extract ResNet50 region features for one image.

    Args:
        img: Image accepted by ``preprocess_image``.

    Returns:
        Backbone output reshaped to ``(1, reg_count, img_embed_dim)``.
    """
    batch = preprocess_image(img)
    feature_map = model.predict(batch, verbose=0)
    # Flatten the spatial grid into a list of region vectors.
    return feature_map.reshape((1, reg_count, img_embed_dim))
class BahdanauAttention(layers.Layer):
    """Additive attention over a set of feature vectors.

    Args:
        units: Width of the two projection layers.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.W1 = layers.Dense(units)
        self.W2 = layers.Dense(units)
        self.V = layers.Dense(1)

    def call(self, features, hidden):
        """Weight ``features`` by their relevance to ``hidden``.

        Returns:
            Tuple ``(context, alpha)``: the attention-weighted sum of
            ``features`` over axis 1, and the softmax weights over axis 1.
        """
        # Give the hidden state an axis 1 so it broadcasts against the features.
        hidden_expanded = tf.expand_dims(hidden, 1)
        projected = self.W1(features) + self.W2(hidden_expanded)
        score = self.V(tf.nn.tanh(projected))
        alpha = tf.nn.softmax(score, axis=1)
        weighted = alpha * features
        context = tf.reduce_sum(weighted, axis=1)
        return context, alpha
class ImageCaptioningModel(tf.keras.Model):
    """LSTM caption decoder with additive attention over image features.

    Args:
        vocab_size: Size of the embedding table and of the output layer.
        max_caption_len: Number of time steps the decoder unrolls.
        embedding_dim: Token embedding width.
        lstm_units: LSTM state width; also the attention projection width.
        dropout_rate: Rate of the embedding and output dropouts.
        **kwargs: Forwarded to the base ``Model`` constructor.
    """

    def __init__(self, vocab_size, max_caption_len, embedding_dim=512,
                 lstm_units=512, dropout_rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_caption_len = max_caption_len
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.dropout_rate = dropout_rate

        # NOTE(review): sublayers are created in the same order as before;
        # loading saved weights may depend on that order — confirm before
        # reordering.
        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        self.embedding_dropout = layers.Dropout(dropout_rate)
        self.lstm = layers.LSTM(lstm_units, return_sequences=True, return_state=True)
        self.attention = BahdanauAttention(lstm_units)
        self.fc_dropout = layers.Dropout(dropout_rate)
        self.fc = layers.Dense(vocab_size, activation='softmax')
        self.init_h = layers.Dense(lstm_units, activation='tanh')
        self.init_c = layers.Dense(lstm_units)
        self.concatenate = layers.Concatenate(axis=-1)

    def call(self, inputs):
        """Predict a token distribution for every caption position.

        Args:
            inputs: Pair ``(features, captions)`` of image region features and
                token ids. Only the first ``max_caption_len`` caption
                positions are read.

        Returns:
            Softmax outputs of the final dense layer for each of the
            ``max_caption_len`` steps, concatenated along axis 1.
        """
        features, captions = inputs

        # The initial LSTM state comes from the mean of the features over axis 1.
        mean_features = tf.reduce_mean(features, axis=1)
        hidden = self.init_h(mean_features)
        cell = self.init_c(mean_features)

        token_embeddings = self.embedding_dropout(self.embedding(captions))

        step_outputs = []
        for step in range(self.max_caption_len):
            # Attend over the image using the current hidden state.
            context, _ = self.attention(features, hidden)
            step_input = self.concatenate([token_embeddings[:, step, :], context])
            step_input = tf.expand_dims(step_input, 1)
            step_output, hidden, cell = self.lstm(
                step_input, initial_state=[hidden, cell]
            )
            step_outputs.append(step_output)

        sequence = tf.concat(step_outputs, axis=1)
        sequence = self.fc_dropout(sequence)
        return self.fc(sequence)
# Build the LSTM captioner; it unrolls sentence_length - 1 decoding steps.
rnn_model = ImageCaptioningModel(vocab_size, sentence_length-1, rnn_embedding_dim, rnn_proj_dim)
# Random dummy inputs for one forward pass before loading weights
# (presumably so that all variables exist — confirm).
# NOTE(review): the dummy batch uses the full `batch_size` (128); a single example
# would presumably be enough for this purpose — confirm.
image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
_ = rnn_model([image_input, text_input])
# NOTE(review): the model's last layer applies softmax and the loss uses
# from_logits=False, while Perplexity computes its cross-entropy with
# from_logits=True — the metric looks inconsistent for this model; confirm.
rnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[Perplexity()]
)
# NOTE: `save_dir` already ends with "/", so the f-string below yields
# "models//rnn_attn.weights.h5".
save_dir = "models/"
model_name = "rnn_attn"
rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
# Beam-search settings shared by the caption generators below.
# Number of hypotheses kept after every decoding step.
beam_width=3
# Upper bound on decoding steps and width of the decoder's text input.
max_length=sentence_length-1
# Exponent 1/temperature is applied to the RNN model's probabilities only.
temperature=1.0
# Per-channel normalization constants applied after scaling pixels to [0, 1].
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]


def load_and_preprocess_image(img):
    """Prepare a single image for the CoCa encoder.

    Args:
        img: A PIL image, or an array-like / tensor accepted by
            ``tf.convert_to_tensor``, with channels on the last axis.

    Returns:
        A numpy array with a leading batch axis of size 1 and channels moved
        to the first image axis, normalized with ``image_mean`` and
        ``image_std``.
    """
    # PIL images expose `.mode`; anything other than RGB (e.g. RGBA, L, P)
    # would not match the three-channel mean/std, so force three channels.
    # Arrays and tensors have no `.mode` and pass through unchanged.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    img = tf.convert_to_tensor(img)
    img = tf.image.resize(img, (img_size, img_size))
    img = img / 255.0
    img = (img - image_mean) / image_std
    # Channels-last -> channels-first.
    img = tf.transpose(img, perm=[2, 0, 1])
    return np.expand_dims(img, axis=0)
def has_repeated_ngrams(seq, n=2):
    """Return True when any contiguous n-gram occurs more than once in *seq*.

    Args:
        seq: Sliceable sequence of hashable items.
        n: Length of the n-grams to compare (defaults to 2).
    """
    seen = set()
    for start in range(len(seq) - n + 1):
        gram = tuple(seq[start:start + n])
        if gram in seen:
            return True
        seen.add(gram)
    return False
# NOTE(review): image_mean, image_std and load_and_preprocess_image are also
# defined earlier in this file; this later definition is the one in effect at
# call time, so one of the two copies can presumably be removed — confirm.
# Per-channel normalization constants applied after scaling pixels to [0, 1].
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]


def load_and_preprocess_image(img):
    """Prepare a single image for the CoCa encoder.

    Args:
        img: A PIL image, or an array-like / tensor accepted by
            ``tf.convert_to_tensor``, with channels on the last axis.

    Returns:
        A numpy array with a leading batch axis of size 1 and channels moved
        to the first image axis, normalized with ``image_mean`` and
        ``image_std``.
    """
    # PIL images expose `.mode`; anything other than RGB (e.g. RGBA, L, P)
    # would not match the three-channel mean/std, so force three channels.
    # Arrays and tensors have no `.mode` and pass through unchanged.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    img = tf.convert_to_tensor(img)
    img = tf.image.resize(img, (img_size, img_size))
    img = img / 255.0
    img = (img - image_mean) / image_std
    # Channels-last -> channels-first.
    img = tf.transpose(img, perm=[2, 0, 1])
    return np.expand_dims(img, axis=0)
# def generate_caption_coca(image):
# img_processed = load_and_preprocess_image(image)
# _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
# cap_features = cap_features.astype(np.float32)
# start_token_id = word_index[start_token]
# end_token_id = word_index[end_token]
# sequence = [start_token_id]
# text_input = np.zeros((1, sentence_length - 1), dtype=np.float32)
# for t in range(sentence_length - 1):
# text_input[0, :len(sequence)] = sequence
# _, logits = coca_model.decoder.predict(
# [text_input, cap_features],
# verbose=0
# )
# next_token = np.argmax(logits[0, t, :])
# sequence.append(next_token)
# if next_token == end_token_id or len(sequence) >= (sentence_length - 1):
# break
# caption = " ".join(
# [index_word[token] for token in sequence
# if token not in {word_index[start_token], word_index[end_token]}]
# )
# return caption
def generate_caption_coca(image):
    """Caption *image* with the CoCa model using beam search.

    Each hypothesis carries a length-averaged log-probability. A hypothesis
    that contains a repeated bigram is penalized by 0.5. The search keeps
    ``beam_width`` hypotheses per step and stops after ``max_length`` steps or
    once every kept hypothesis ends with the end token.

    Returns:
        The best hypothesis as space-joined words, without start/end tokens.
    """
    bos_id = word_index[start_token]
    eos_id = word_index[end_token]

    pixels = load_and_preprocess_image(image)
    _, cap_features = coca_model.encoder.predict(pixels, verbose=0)

    beams = [([bos_id], 0.0)]
    for _ in range(max_length):
        candidates = []
        for tokens, score in beams:
            # Finished hypotheses are carried over unchanged.
            if tokens[-1] == eos_id:
                candidates.append((tokens, score))
                continue

            n_tokens = len(tokens)
            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :n_tokens] = tokens

            _, logits = coca_model.decoder.predict(
                [text_input, cap_features], verbose=0
            )
            # Softmax over the logits at the last filled position.
            step_logits = logits[0, n_tokens - 1, :]
            probs = np.exp(step_logits - np.max(step_logits))
            probs /= probs.sum()

            for token in np.argsort(-probs)[:beam_width]:
                extended = tokens + [token]
                # Running mean of log-probabilities over the hypothesis length.
                new_score = (score * n_tokens + np.log(probs[token])) / (n_tokens + 1)
                if has_repeated_ngrams(extended, n=2):
                    new_score -= 0.5
                candidates.append((extended, new_score))

        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(beam[0][-1] == eos_id for beam in beams):
            break

    best_seq = max(beams, key=lambda x: x[1])[0]
    special_ids = {bos_id, eos_id}
    return " ".join(index_word[i] for i in best_seq if i not in special_ids)
def generate_caption_rnn(image):
    """Caption *image* with the LSTM attention model using beam search.

    Each hypothesis carries a length-averaged log-probability. A hypothesis
    that contains a repeated bigram is penalized by 0.5. The search keeps
    ``beam_width`` hypotheses per step and stops after ``max_length`` steps or
    once every kept hypothesis ends with the end token.

    Returns:
        The best hypothesis as space-joined words, without start/end tokens.
    """
    bos_id = word_index[start_token]
    eos_id = word_index[end_token]

    image_embedding = create_features(image)

    beams = [([bos_id], 0.0)]
    for _ in range(max_length):
        candidates = []
        for tokens, score in beams:
            # Finished hypotheses are carried over unchanged.
            if tokens[-1] == eos_id:
                candidates.append((tokens, score))
                continue

            n_tokens = len(tokens)
            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :n_tokens] = tokens

            predictions = rnn_model.predict([image_embedding, text_input], verbose=0)
            # The model already outputs probabilities; apply the temperature
            # and renormalize.
            probs = predictions[0, n_tokens - 1, :]
            probs = probs ** (1 / temperature)
            probs /= probs.sum()

            # Unordered top-k; the sort after the loop establishes the ranking.
            for token in np.argpartition(probs, -beam_width)[-beam_width:]:
                extended = tokens + [token]
                # Running mean of log-probabilities over the hypothesis length.
                new_score = (score * n_tokens + np.log(probs[token])) / (n_tokens + 1)
                if has_repeated_ngrams(extended, n=2):
                    new_score -= 0.5
                candidates.append((extended, new_score))

        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(beam[0][-1] == eos_id for beam in beams):
            break

    best_seq = max(beams, key=lambda x: x[1])[0]
    special_ids = {bos_id, eos_id}
    return " ".join(index_word[i] for i in best_seq if i not in special_ids)
def generate_both(image):
    """Caption *image* with both models and format the result for display.

    Args:
        image: Image to caption, or ``None`` when the form was submitted
            without one.

    Returns:
        A string with the RNN caption and the CoCa caption on separate
        paragraphs, or an empty string when ``image`` is ``None``.
    """
    # Without an image there is nothing to caption; return early instead of
    # failing inside the preprocessing code.
    if image is None:
        return ""
    caption1 = generate_caption_rnn(image)
    caption2 = generate_caption_coca(image)
    return f"RNN: {caption1}\n\nCoCa: {caption2}"
# Single-function UI: one PIL image in, one text box with both captions out.
interface = gr.Interface(
    fn=generate_both,
    inputs=gr.Image(type="pil", label="Изображение"),
    outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
    allow_flagging="never",
    submit_btn="Сгенерировать",
    clear_btn="Очистить",
    deep_link=False
)
# Wrap the interface in a Blocks page so a Markdown title can be shown above it.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ Генератор описаний к изображениям")
    interface.render()
# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_api=False)