# app.py — Hugging Face Space by peshk1n (commit 6b8da2f): image-captioning
# demo comparing an RNN captioner (ResNet50 + Bahdanau attention) with a CoCa model.
from tensorflow import keras
import numpy as np
import tensorflow as tf
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io
from PIL import Image
import json
from tensorflow.keras import layers, Model
import string
from transformers import TFAutoModel
import gradio as gr
import os
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
# NOTE(review): Keras reads KERAS_BACKEND at import time; tensorflow/keras are
# already imported above, so setting it here is likely a no-op — confirm intent.
os.environ["KERAS_BACKEND"] = "tensorflow"
# Hyperparameters / configuration ================================
start_token = "[BOS]"  # beginning-of-sequence marker
end_token = "[EOS]"    # end-of-sequence marker
cls_token = "[CLS]"    # token appended for the contrastive (CLS) head
data_dir = '/content/coco'       # COCO root (training-time path; unused at inference)
data_type_train = 'train2014'
data_type_val = 'val2014'
vocab_size = 24000        # caption vocabulary size
sentence_length = 20      # maximum caption length in tokens
batch_size = 128
img_size = 224            # input image side length (pixels)
proj_dim = 192            # transformer embedding / projection width
dropout_rate = 0.1
num_patches = 14          # patches per image side for the ViT
patch_size = img_size // num_patches
num_heads = 3             # attention heads per transformer block
num_layers = 6            # transformer blocks per decoder stack
attn_pool_dim = proj_dim      # width of the attentional-pooling queries
attn_pool_heads = num_heads
cap_query_num = 128       # number of learned caption-pooling queries
# RNN captioner hyperparameters
rnn_embedding_dim = 256
rnn_proj_dim = 512
# =================================
# Load word -> id mapping (values wrapped in numpy scalar types to match how
# the vocabulary was produced during training).
with open('vocabs/word_index.json', 'r', encoding='utf-8') as f:
    word_index = {np.str_(word): np.int64(idx) for word, idx in json.load(f).items()}
# Load id -> word mapping (JSON keys are strings; np.int64() parses them to ints).
with open('vocabs/index_word.json', 'r', encoding='utf-8') as f:
    index_word = {np.int64(idx): np.str_(word) for idx, word in json.load(f).items()}
cls_token_id = word_index[cls_token]  # id of the [CLS] token used by the decoder
class PositionalEmbedding(layers.Layer):
    """Token embedding plus learned absolute position embedding.

    Args:
        sequence_length: maximum sequence length (size of the position table).
        input_dim: vocabulary size for the token embedding.
        output_dim: embedding width.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )

    def call(self, inputs):
        """Embed token ids and add position embeddings.

        Fix: position indices are derived from the runtime length of `inputs`
        instead of the fixed `sequence_length`, so sequences shorter than the
        maximum no longer broadcast-fail; full-length inputs (the only case
        exercised in this file) behave exactly as before.
        """
        seq_len = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions
class AttentionalPooling(tf.keras.layers.Layer):
    """Pool a feature set with learned queries via multi-head attention.

    Fix: `__init__` now forwards `**kwargs` (e.g. `name`, `dtype`) to the base
    Layer, matching the convention of the other custom layers in this file.
    Existing positional callers are unaffected.
    """

    def __init__(self, embed_dim, num_heads=6, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.multihead_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim
        )
        self.norm = layers.LayerNormalization()

    def call(self, features, query):
        """Attend from `query` (B, Q, D) over `features` (B, S, D) -> (B, Q, D), normalized."""
        attn_output = self.multihead_attention(
            query=query,
            value=features,
            key=features
        )
        return self.norm(attn_output)
class TransformerBlock(tf.keras.layers.Layer):
    """Transformer decoder block: causal self-attention, optional cross-attention, FFN.

    With `is_multimodal=True` the block also builds a cross-attention sub-layer
    (norm2 / dropout2 / cross_attention) for attending over encoder features.
    NOTE(review): `call` gates cross-attention on `encoder_outputs is not None`,
    not on `is_multimodal` — passing encoder outputs to a block built without
    cross-attention would raise AttributeError; callers in this file keep the
    two in sync.
    NOTE(review): the first residual adds onto `norm1(inputs)` rather than the
    raw `inputs` — non-standard pre-norm, but preserved as-is since the
    checkpoint was trained this way.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, is_multimodal=False, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        # Self-Attention (always causal via the mask built in call()).
        self.self_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim,
            dropout=self.dropout_rate
        )
        # Cross-Attention — created only for multimodal blocks.
        if is_multimodal:
            self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
            self.dropout2 = layers.Dropout(self.dropout_rate)
            self.cross_attention = layers.MultiHeadAttention(
                num_heads=self.num_heads,
                key_dim=self.embed_dim,
                dropout=self.dropout_rate
            )
        # Feed-Forward Network
        self.dense_proj = tf.keras.Sequential([
            layers.Dense(self.dense_dim, activation="gelu"),
            layers.Dropout(self.dropout_rate),
            layers.Dense(self.embed_dim)
        ])
        # Layer Normalization (applied before each sub-layer)
        self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        # Dropout on each residual branch
        self.dropout1 = layers.Dropout(self.dropout_rate)
        self.dropout3 = layers.Dropout(self.dropout_rate)

    def get_causal_attention_mask(self, inputs):
        """Lower-triangular boolean mask of shape (1, L, L): position i sees j <= i."""
        seq_len = tf.shape(inputs)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0)
        return tf.expand_dims(causal_mask, 0)

    def get_combined_mask(self, causal_mask, padding_mask):
        """AND the causal mask (1, L, L) with a key padding mask (B, L) -> (B, L, L)."""
        padding_mask = tf.cast(padding_mask, tf.bool)
        padding_mask = tf.expand_dims(padding_mask, 1)  # (B, 1, L)
        return causal_mask & padding_mask

    def call(self, inputs, encoder_outputs=None, mask=None):
        """Run the block.

        Args:
            inputs: (B, L, embed_dim) token states.
            encoder_outputs: optional (B, S, embed_dim) features to cross-attend over.
            mask: optional (B, L) padding mask, truthy for real tokens.
        Returns:
            (B, L, embed_dim) updated token states.
        """
        att_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            att_mask = self.get_combined_mask(att_mask, mask)
        # Self-Attention
        x = self.norm1(inputs)
        attention_output_1 = self.self_attention(
            query=x, key=x, value=x, attention_mask=att_mask
        )
        attention_output_1 = self.dropout1(attention_output_1)
        x = x + attention_output_1  # residual
        # Cross-Attention (only when encoder features are supplied)
        if encoder_outputs is not None:
            x_norm = self.norm2(x)
            attention_output_2 = self.cross_attention(
                query=x_norm, key=encoder_outputs, value=encoder_outputs
            )
            attention_output_2 = self.dropout2(attention_output_2)
            x = x + attention_output_2  # residual
        # Feed-Forward Network (FFN)
        x_norm = self.norm3(x)
        proj_output = self.dense_proj(x_norm)
        proj_output = self.dropout3(proj_output)
        return x + proj_output  # residual
class UnimodalTextDecoder(tf.keras.layers.Layer):
    """Stack of text-only (causal self-attention) TransformerBlocks + final LayerNorm."""

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__()
        # Record the stack configuration.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers
        # Text-only blocks: no cross-attention sub-layer is built.
        blocks = []
        for _ in range(self.num_layers):
            blocks.append(
                TransformerBlock(
                    self.embed_dim,
                    self.dense_dim,
                    self.num_heads,
                    self.dropout_rate,
                    self.ln_epsilon,
                    is_multimodal=False,
                )
            )
        self.layers = blocks
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, mask=None):
        """Pass `x` through every block (propagating the padding mask), then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(inputs=hidden, mask=mask)
        return self.norm(hidden)
class MultimodalTextDecoder(tf.keras.layers.Layer):
    """Stack of cross-attending TransformerBlocks + final LayerNorm."""

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__()
        # Record the stack configuration.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers
        # Multimodal blocks: each one carries a cross-attention sub-layer.
        blocks = []
        for _ in range(self.num_layers):
            blocks.append(
                TransformerBlock(
                    self.embed_dim,
                    self.dense_dim,
                    self.num_heads,
                    self.dropout_rate,
                    self.ln_epsilon,
                    is_multimodal=True,
                )
            )
        self.layers = blocks
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, encoder_outputs, mask=None):
        """Pass `x` through every block, cross-attending to `encoder_outputs`; normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(inputs=hidden, encoder_outputs=encoder_outputs, mask=mask)
        return self.norm(hidden)
class EmbedToLatents(layers.Layer):
    """Project embeddings into the shared contrastive latent space, L2-normalized."""

    def __init__(self, dim_latents, **kwargs):
        super().__init__(**kwargs)
        self.dim_latents = dim_latents
        # Bias-free projection head.
        self.to_latents = layers.Dense(self.dim_latents, use_bias=False)

    def call(self, inputs):
        projected = self.to_latents(inputs)
        # Unit-normalize along the feature axis so dot products are cosine similarities.
        return tf.math.l2_normalize(projected, axis=-1)
class Perplexity(tf.keras.metrics.Metric):
    """Streaming perplexity: exp(masked token-loss sum / token count).

    Tokens equal to 0 (padding) are excluded from both the loss sum and the
    token count.
    """

    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        self.total_loss = self.add_weight(name='total_loss', initializer='zeros')
        self.total_tokens = self.add_weight(name='total_tokens', initializer='zeros')
        # Fix: build the loss object once instead of on every update_state call.
        self._loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none'
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-token cross-entropy over non-pad positions."""
        loss = self._loss_fn(y_true, y_pred)
        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)  # 1.0 for real tokens
        self.total_loss.assign_add(tf.reduce_sum(loss * mask))
        self.total_tokens.assign_add(tf.reduce_sum(mask))

    def result(self):
        # Fix: divide_no_nan keeps the result finite (exp(0) = 1) before any update,
        # where the original produced NaN from 0/0.
        return tf.exp(tf.math.divide_no_nan(self.total_loss, self.total_tokens))

    def reset_state(self):
        # Fix: Keras 3 invokes `reset_state`; the original defined only the legacy
        # `reset_states` name, so resets fell through to the base implementation.
        self.total_loss.assign(0.0)
        self.total_tokens.assign(0.0)

    # Backward-compatible alias for the TF 2.x method name.
    reset_states = reset_state
# Backbone for the CoCa encoder: a tiny ViT from Hugging Face, fully unfrozen
# so every layer's weights can be restored from the fine-tuned checkpoint below.
model_name = "WinKawaks/vit-tiny-patch16-224"
vit_tiny_model = TFAutoModel.from_pretrained(model_name)
vit_tiny_model.trainable = True
for layer in vit_tiny_model.layers:
    layer.trainable = True
class CoCaEncoder(tf.keras.Model):
    """Image encoder: ViT patch features pooled by two attentional-pooling heads.

    Produces one contrastive embedding per image and a set of caption features
    for the multimodal decoder to cross-attend over.
    """

    def __init__(self,
                 vit, **kwargs):
        super().__init__(**kwargs)
        self.vit = vit
        self.contrastive_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
        self.caption_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
        # Single learned query -> one contrastive token per image.
        self.con_query = tf.Variable(
            initial_value=tf.random.normal([1, 1, proj_dim]),
            trainable=True,
            name="con_query"
        )
        # cap_query_num learned queries -> caption feature set per image.
        self.cap_query = tf.Variable(
            initial_value=tf.random.normal([1, cap_query_num, proj_dim]),
            trainable=True,
            name="cap_query"
        )

    def call(self, input, training=False):
        """Encode a channels-first image batch into (contrastive, caption) features."""
        patch_tokens = self.vit(input).last_hidden_state
        n = tf.shape(patch_tokens)[0]
        # Broadcast the learned queries across the batch.
        con_queries = tf.tile(self.con_query, [n, 1, 1])
        cap_queries = tf.tile(self.cap_query, [n, 1, 1])
        con_feature = self.contrastive_pooling(patch_tokens, con_queries)
        cap_feature = self.caption_pooling(patch_tokens, cap_queries)
        return con_feature, cap_feature
class CoCaDecoder(tf.keras.Model):
    """CoCa text decoder.

    Appends a [CLS] token to each caption, runs the unimodal stack, then feeds
    the text positions (without [CLS]) through the multimodal stack which
    cross-attends to the image caption features.

    Returns:
        cls_token_feature: (B, 1, proj_dim) normalized [CLS] state for the
            contrastive head.
        multimodal_logits: (B, T, vocab_size) next-token logits.
    """

    def __init__(self,
                 cls_token_id,
                 num_heads,
                 num_layers,
                 **kwargs):
        super().__init__(**kwargs)
        self.cls_token_id = cls_token_id
        # Position table sized for text (T = sentence_length-1) plus the [CLS] slot.
        self.pos_emb = PositionalEmbedding(sentence_length, vocab_size, proj_dim)
        self.unimodal_decoder = UnimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.multimodal_decoder = MultimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.to_logits = tf.keras.layers.Dense(
            vocab_size,
            name='logits_projection'
        )
        self.norm = layers.LayerNormalization()

    def call(self, inputs, training=False):
        input_text, cap_feature = inputs
        batch_size = tf.shape(input_text)[0]
        # Append one [CLS] token at the end of every caption.
        cls_tokens = tf.fill([batch_size, 1], tf.cast(self.cls_token_id, input_text.dtype))
        ids = tf.concat([input_text, cls_tokens], axis=1)
        # Padding mask: True for real tokens; 0 is the pad id.
        text_mask = tf.not_equal(input_text, 0)
        # NOTE(review): the [CLS] position is appended as False, i.e. masked out
        # as an attention *key* — text tokens cannot attend to [CLS], while the
        # [CLS] query row (last causal row) still attends to the text. Confirm
        # this matches the training setup.
        cls_mask = tf.zeros([batch_size, 1], dtype=text_mask.dtype)
        extended_mask = tf.concat([text_mask, cls_mask], axis=1)
        txt_embs = self.pos_emb(ids)
        unimodal_out = self.unimodal_decoder(txt_embs, mask=extended_mask)
        # The multimodal stack sees only the text positions (drop the [CLS] slot).
        multimodal_out = self.multimodal_decoder(unimodal_out[:, :-1, :], cap_feature, mask=text_mask)
        # Normalized [CLS] state serves as the text embedding for the contrastive loss.
        cls_token_feature = self.norm(unimodal_out[:, -1:, :])
        multimodal_logits = self.to_logits(multimodal_out)
        return cls_token_feature, multimodal_logits
# (dev note: "day 6")
class CoCaModel(tf.keras.Model):
    """Full CoCa model: image encoder + text decoder with two training losses.

    Training minimizes a weighted sum of:
      * caption loss — sparse cross-entropy over next-token logits, pad ignored;
      * contrastive loss — symmetric InfoNCE between image and [CLS] text latents
        with in-batch negatives.
    """

    def __init__(self,
                 vit,
                 cls_token_id,
                 num_heads,
                 num_layers):
        super().__init__()
        self.encoder = CoCaEncoder(vit, name="coca_encoder")
        self.decoder = CoCaDecoder(cls_token_id, num_heads, num_layers, name="coca_decoder")
        # Separate projection heads into the shared contrastive latent space.
        self.img_to_latents = EmbedToLatents(proj_dim)
        self.text_to_latents = EmbedToLatents(proj_dim)
        self.pad_id = 0  # padding token id, ignored by the caption loss
        self.temperature = 0.2  # contrastive softmax temperature (earlier tries: 0.5, 0.9, 1.0)
        self.caption_loss_weight = 1.0
        self.contrastive_loss_weight = 1.0
        self.perplexity = Perplexity()

    def call(self, inputs, training=False):
        """Forward pass -> (contrastive feature, [CLS] feature, caption logits)."""
        image, text = inputs
        con_feature, cap_feature = self.encoder(image)
        cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
        return con_feature, cls_token_feature, multimodal_logits

    def compile(self, optimizer):
        """Custom compile: only an optimizer; losses are computed in train/test_step."""
        super().compile()
        self.optimizer = optimizer

    def compute_caption_loss(self, multimodal_out, caption_target):
        """Mean next-token cross-entropy over non-pad positions (logits input)."""
        caption_loss = tf.keras.losses.sparse_categorical_crossentropy(
            caption_target, multimodal_out, from_logits=True, ignore_class=self.pad_id)
        return tf.reduce_mean(caption_loss)

    def compute_contrastive_loss(self, con_feature, cls_feature):
        """Symmetric InfoNCE between text and image latents; diagonal = matching pairs."""
        text_embeds = tf.squeeze(cls_feature, axis=1)
        image_embeds = tf.squeeze(con_feature, axis=1)
        text_latents = self.text_to_latents(text_embeds)
        image_latents = self.img_to_latents(image_embeds)
        # Similarity matrix (latents are L2-normalized), scaled by temperature.
        sim = tf.matmul(text_latents, image_latents, transpose_b=True) / self.temperature
        # Labels: the i-th text matches the i-th image.
        batch_size = tf.shape(sim)[0]
        contrastive_labels = tf.range(batch_size)
        # Cross-entropy in both directions (text->image and image->text).
        loss1 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, sim, from_logits=True)
        loss2 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, tf.transpose(sim), from_logits=True)
        contrastive_loss = tf.reduce_mean((loss1 + loss2) * 0.5)
        return contrastive_loss

    def train_step(self, data):
        """One optimization step over ((images, caption_input), caption_target)."""
        (images, caption_input), caption_target = data
        with tf.GradientTape() as tape:
            con_feature, cls_feature, multimodal_out = self([images, caption_input], training=True)
            caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
            contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
            total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.perplexity.update_state(caption_target, multimodal_out)
        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result()
        }

    def test_step(self, data):
        """Evaluation step: same losses as train_step, no gradient update."""
        (images, caption_input), caption_target = data
        con_feature, cls_feature, multimodal_out = self([images, caption_input], training=False)
        caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
        contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
        total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss
        self.perplexity.update_state(caption_target, multimodal_out)
        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result()
        }

    def reset_metrics(self):
        """Reset the streaming perplexity between epochs."""
        self.perplexity.reset_state()
# ===========================================
# Build the CoCa model, trace it once with dummy inputs so every variable is
# created, then restore the fine-tuned weights.
coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id, num_heads=num_heads, num_layers=num_layers)
dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)  # channels-first image batch
dummy_captions = tf.zeros((1, sentence_length-1), dtype=tf.int64)        # token-id batch
_ = coca_model((dummy_features, dummy_captions))
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
coca_model.compile(optimizer)
save_dir = "models/"
model_name = "coca_007"
coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
# ===========================================
# RNN captioner =============================
img_embed_dim = 2048  # channel count of the ResNet50 feature map
reg_count = 7 * 7     # spatial regions in the 7x7 feature map (224px input)
base_model = ResNet50(weights='imagenet', include_top=False)
# Headless ResNet50 used purely as a frozen feature extractor.
model = Model(inputs=base_model.input, outputs=base_model.output)
def preprocess_image(img):
    """Resize to (img_size, img_size), apply ResNet50 preprocessing, add a batch dim.

    Accepts anything tensor-convertible (PIL image, ndarray); returns a numpy
    array of shape (1, img_size, img_size, 3).
    """
    # Fix: tf.image.resize already converts its input to a tensor, so the
    # original redundant tf.convert_to_tensor call after it was removed.
    img = tf.image.resize(img, (img_size, img_size))
    img = preprocess_input(img)
    return np.expand_dims(img, axis=0)
def create_features(img):
    """Extract ResNet50 region features for one image: shape (1, reg_count, img_embed_dim)."""
    batch = preprocess_image(img)
    feature_map = model.predict(batch, verbose=0)
    # Flatten the 7x7 spatial grid into a sequence of region vectors.
    return feature_map.reshape((1, reg_count, img_embed_dim))
class BahdanauAttention(layers.Layer):
    """Additive (Bahdanau) attention over image region features."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.W1 = layers.Dense(units)  # projects the region features
        self.W2 = layers.Dense(units)  # projects the decoder hidden state
        self.V = layers.Dense(1)       # scores each region

    def call(self, features, hidden):
        """Return (context vector, attention weights) for one decoding step."""
        # (B, units) -> (B, 1, units) so it broadcasts over the regions axis.
        hidden = tf.expand_dims(hidden, 1)
        scores = self.V(tf.nn.tanh(self.W1(features) + self.W2(hidden)))
        weights = tf.nn.softmax(scores, axis=1)
        # Weighted sum of region features -> context vector.
        context = tf.reduce_sum(weights * features, axis=1)
        return context, weights
class ImageCaptioningModel(tf.keras.Model):
    """LSTM captioner with Bahdanau attention over ResNet region features.

    Each timestep concatenates the current token embedding with an attention
    context vector and runs one LSTM step; a softmax head emits the
    per-timestep vocabulary distribution.
    """

    def __init__(self, vocab_size, max_caption_len, embedding_dim=512, lstm_units=512, dropout_rate=0.5, **kwargs):
        super().__init__(**kwargs)
        # Hyperparameters.
        self.vocab_size = vocab_size
        self.max_caption_len = max_caption_len
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.dropout_rate = dropout_rate
        # Token embedding with dropout.
        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        self.embedding_dropout = layers.Dropout(dropout_rate)
        # LSTM is driven one timestep at a time; states are carried manually.
        self.lstm = layers.LSTM(lstm_units, return_sequences=True, return_state=True)
        self.attention = BahdanauAttention(lstm_units)
        # Output head.
        self.fc_dropout = layers.Dropout(dropout_rate)
        self.fc = layers.Dense(vocab_size, activation='softmax')
        # Dense layers deriving the initial LSTM state from mean image features.
        self.init_h = layers.Dense(lstm_units, activation='tanh')
        self.init_c = layers.Dense(lstm_units)
        self.concatenate = layers.Concatenate(axis=-1)

    def call(self, inputs):
        """Teacher-forced forward pass over (features, captions)."""
        features, captions = inputs
        # Initialise the LSTM state from the mean-pooled image features.
        mean_features = tf.reduce_mean(features, axis=1)
        hidden = self.init_h(mean_features)
        cell = self.init_c(mean_features)
        token_embeddings = self.embedding_dropout(self.embedding(captions))
        step_outputs = []
        for step in range(self.max_caption_len):
            # Attend over the regions with the current hidden state.
            context, _ = self.attention(features, hidden)
            step_input = self.concatenate([token_embeddings[:, step, :], context])
            step_input = tf.expand_dims(step_input, 1)
            out, hidden, cell = self.lstm(step_input, initial_state=[hidden, cell])
            step_outputs.append(out)
        sequence = tf.concat(step_outputs, axis=1)
        return self.fc(self.fc_dropout(sequence))
# Build the RNN captioner, trace it once with random inputs so all variables are
# created, then restore the trained weights.
rnn_model = ImageCaptioningModel(vocab_size, sentence_length-1, rnn_embedding_dim, rnn_proj_dim)
image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
# NOTE(review): text_input has sentence_length columns but the model reads only
# the first sentence_length-1 timesteps — confirm the extra column is intended.
text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
_ = rnn_model([image_input, text_input])
rnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[Perplexity()]
)
save_dir = "models/"
model_name = "rnn_att_v4"
rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
# =====================================
# Generation (beam search) settings
beam_width=3
max_length=sentence_length-1  # maximum tokens generated per caption
temperature=1.0               # softmax temperature applied to next-token scores
image_mean = [0.5, 0.5, 0.5]  # per-channel normalization mean for the ViT input
image_std = [0.5, 0.5, 0.5]   # per-channel normalization std for the ViT input
def load_and_preprocess_image(img):
    """Prepare one image for the CoCa ViT: (1, 3, img_size, img_size), normalized."""
    tensor = tf.convert_to_tensor(img)
    tensor = tf.image.resize(tensor, (img_size, img_size))
    # Scale to [0, 1], then standardize with the ViT's per-channel mean/std.
    tensor = tensor / 255.0
    tensor = (tensor - image_mean) / image_std
    # HWC -> CHW: the Hugging Face ViT expects channels-first input.
    tensor = tf.transpose(tensor, perm=[2, 0, 1])
    return np.expand_dims(tensor, axis=0)
def has_repeated_ngrams(seq, n=2):
    """Return True if any length-n window occurs more than once in seq."""
    seen = set()
    # zip of n shifted views yields exactly the len(seq)-n+1 sliding windows.
    for window in zip(*(seq[k:] for k in range(n))):
        if window in seen:
            return True
        seen.add(window)
    return False
# Beam-search caption generation for the CoCa model (with n-gram repeat penalty).
def generate_caption_coca(image):
    """Generate a caption for `image` with the CoCa model via beam search.

    Uses length-normalized log-probabilities (a running average), a softmax
    temperature on the logits, and a flat -0.5 penalty for repeated bigrams.
    """
    img_processed = load_and_preprocess_image(image)
    # Only the caption features feed the decoder; the contrastive token is unused here.
    _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
    # Each beam is (token-id sequence, length-normalized log-probability).
    beams = [([word_index[start_token]], 0.0)]
    for _ in range(max_length):
        new_beams = []
        for seq, log_prob in beams:
            # Finished beams are carried over unchanged.
            if seq[-1] == word_index[end_token]:
                new_beams.append((seq, log_prob))
                continue
            # Left-aligned, zero-padded decoder input of fixed width.
            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :len(seq)] = seq
            predictions = coca_model.decoder.predict([text_input, cap_features], verbose=0)
            _, logits = predictions
            # Logits for the token following the current prefix.
            logits = logits[0, len(seq)-1, :] / temperature
            # Numerically stable softmax.
            probs = np.exp(logits - np.max(logits))
            probs /= probs.sum()
            # Top beam_width candidates (order inside the top-k is arbitrary).
            top_k = np.argpartition(probs, -beam_width)[-beam_width:]
            for token in top_k:
                new_seq = seq + [token]
                # Running average of per-token log-probs (length normalization).
                new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)
                # Penalty for repeated bigrams
                if has_repeated_ngrams(new_seq, n=2):
                    new_log_prob -= 0.5
                new_beams.append((new_seq, new_log_prob))
        # Keep only the beam_width best candidates.
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        # Stop early once every beam has emitted [EOS].
        if all(beam[0][-1] == word_index[end_token] for beam in beams):
            break
    best_seq = max(beams, key=lambda x: x[1])[0]
    # Drop the special tokens when rendering the caption.
    return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
# Beam-search caption generation for the RNN model (with n-gram repeat penalty).
def generate_caption_rnn(image):
    """Generate a caption for `image` with the RNN model via beam search.

    Same search as generate_caption_coca, except the model already outputs
    probabilities (softmax head), so temperature is applied as probs ** (1/T).
    """
    image_embedding = create_features(image)
    # Each beam is (token-id sequence, length-normalized log-probability).
    beams = [([word_index[start_token]], 0.0)]
    for _ in range(max_length):
        new_beams = []
        for seq, log_prob in beams:
            # Finished beams are carried over unchanged.
            if seq[-1] == word_index[end_token]:
                new_beams.append((seq, log_prob))
                continue
            # Left-aligned, zero-padded decoder input of fixed width.
            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :len(seq)] = seq
            predictions = rnn_model.predict([image_embedding, text_input], verbose=0)
            # Distribution over the token following the current prefix.
            probs = predictions[0, len(seq)-1, :]
            # Temperature sharpening/flattening, then renormalize.
            probs = probs ** (1 / temperature)
            probs /= probs.sum()
            # Top beam_width candidates (order inside the top-k is arbitrary).
            top_k = np.argpartition(probs, -beam_width)[-beam_width:]
            for token in top_k:
                new_seq = seq + [token]
                # Running average of per-token log-probs (length normalization).
                new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)
                # Penalty for repeated bigrams
                if has_repeated_ngrams(new_seq, n=2):
                    new_log_prob -= 0.5
                new_beams.append((new_seq, new_log_prob))
        # Keep only the beam_width best candidates.
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        # Stop early once every beam has emitted [EOS].
        if all(beam[0][-1] == word_index[end_token] for beam in beams):
            break
    best_seq = max(beams, key=lambda x: x[1])[0]
    # Drop the special tokens when rendering the caption.
    return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
def generate_both(image):
    """Run both captioners on one image and format the results for the UI."""
    rnn_caption = generate_caption_rnn(image)
    coca_caption = generate_caption_coca(image)
    return "\n\n".join([f"RNN: {rnn_caption}", f"CoCa: {coca_caption}"])
# interface = gr.Interface(
# fn=generate_both,
# inputs=gr.Image(type="pil", label="Изображение"),
# outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
# title="Генератор описаний к изображениям",
# allow_flagging="never",
# submit_btn="Сгенерировать",
# clear_btn="Очистить"
# )
#------------------------------
# Nudge the "hosted by HF" badge into the bottom-right corner of the page.
css = """
#hosted-by-hf {
top: unset !important;
bottom: 20px !important;
right: 20px !important;
}
"""
# Gradio interface: one image in, one text box out containing both captions.
interface = gr.Interface(
    fn=generate_both,
    inputs=gr.Image(type="pil", label="Изображение"),
    outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
    allow_flagging="never",
    submit_btn="Сгенерировать",
    clear_btn="Очистить",
    deep_link=False
)
# Wrap the interface in Blocks so the custom CSS and the title markdown apply.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🖼️ Генератор описаний к изображениям")
    interface.render()
# if __name__ == "__main__":
# #interface.launch(ssr_mode=False)
# demo.launch(ssr_mode=False)
# custom_css = """
# footer {visibility: hidden !important;}
# .share-button {display: none !important;}
# #component-1 {margin-top: -1.5rem !important;} # Уменьшаем отступ сверху
# """
# interface = gr.Interface(
# fn=generate_both,
# inputs=gr.Image(type="pil", label="Изображение"),
# outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
# allow_flagging="never",
# submit_btn="Сгенерировать",
# clear_btn="Очистить"
# )
# with gr.Blocks(css=custom_css) as demo:
# gr.Markdown("## 🖼️ Генератор описаний к изображениям")
# interface.render()
if __name__ == "__main__":
    # Disable server-side rendering and hide the API page for this simple demo.
    demo.launch(
        ssr_mode=False,
        show_api=False
    )