"""Image-captioning demo serving two models through a Gradio web UI.

Model 1 ("CoCa"): a CoCa-style architecture — a frozen-interface ViT-tiny image
encoder with attentional pooling, a unimodal text decoder, and a multimodal
(cross-attention) text decoder trained with a captioning + contrastive loss.

Model 2 ("RNN"): a classic show-attend-and-tell captioner — ResNet50 region
features decoded by an LSTM with Bahdanau attention.

Both models decode with beam search (width 3) and a bigram-repetition penalty.

NOTE(review): this file was reconstructed from a whitespace-mangled paste.
Logic is unchanged except: KERAS_BACKEND is now set before keras is imported
(it was a no-op after), exact-duplicate imports/definitions were removed, a
dead commented-out greedy decoder was deleted, and Perplexity.reset_states was
renamed to reset_state to match the Keras 3 API and its call site below.
"""

import json
import os

# Must be set BEFORE any keras import, otherwise the backend choice is ignored.
os.environ["KERAS_BACKEND"] = "tensorflow"

import string  # noqa: F401  (kept: imported by the original module)

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image  # noqa: F401  (kept: gradio hands the callback a PIL image)
from tensorflow import keras  # noqa: F401
from tensorflow import data as tf_data  # noqa: F401
from tensorflow import image as tf_image  # noqa: F401
from tensorflow import io as tf_io  # noqa: F401
from tensorflow.keras import layers, Model
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image  # noqa: F401
from transformers import TFAutoModel

# ---------------------------------------------------------------------------
# Vocabulary special tokens and dataset locations.
# ---------------------------------------------------------------------------
start_token = "[BOS]"
end_token = "[EOS]"
cls_token = "[CLS]"

data_dir = '/content/coco'
data_type_train = 'train2014'
data_type_val = 'val2014'

# ---------------------------------------------------------------------------
# Shared hyperparameters (must match the values the checkpoints were trained with).
# ---------------------------------------------------------------------------
vocab_size = 24000
sentence_length = 20          # caption length INCLUDING the appended [CLS] slot
batch_size = 128
img_size = 224
proj_dim = 192                # transformer embedding width
dropout_rate = 0.1
num_patches = 14
patch_size = img_size // num_patches
num_heads = 3
num_layers = 6
attn_pool_dim = proj_dim
attn_pool_heads = num_heads
cap_query_num = 128           # number of learned queries for the caption pooler

# RNN captioner dimensions.
rnn_embedding_dim = 256
rnn_proj_dim = 512

# ---------------------------------------------------------------------------
# Vocabulary lookup tables (word -> id and id -> word).
# NOTE(review): keys/values are wrapped in numpy scalar types to match how the
# training pipeline indexed them; beam search below relies on np.int64 keys.
# ---------------------------------------------------------------------------
with open('vocabs/word_index.json', 'r', encoding='utf-8') as f:
    word_index = {np.str_(word): np.int64(idx) for word, idx in json.load(f).items()}
with open('vocabs/index_word.json', 'r', encoding='utf-8') as f:
    index_word = {np.int64(idx): np.str_(word) for idx, word in json.load(f).items()}

cls_token_id = word_index[cls_token]


class PositionalEmbedding(layers.Layer):
    """Token embedding plus learned absolute position embedding."""

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )

    def call(self, inputs):
        # Positions are always 0..sequence_length-1; inputs must be padded to
        # exactly `sequence_length` tokens.
        positions = tf.range(start=0, limit=self.sequence_length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        output = embedded_tokens + embedded_positions
        return output


class AttentionalPooling(layers.Layer):
    """Pool a variable-length feature sequence with learned queries (CoCa-style)."""

    def __init__(self, embed_dim, num_heads=6):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.multihead_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim
        )
        self.norm = layers.LayerNormalization()

    def call(self, features, query):
        # `query` carries the learned pooling queries; output length == query length.
        attn_output = self.multihead_attention(
            query=query, value=features, key=features
        )
        return self.norm(attn_output)


class TransformerBlock(layers.Layer):
    """Pre-norm decoder block: causal self-attention, optional cross-attention, FFN.

    When ``is_multimodal`` is True the block also cross-attends to
    ``encoder_outputs`` (pooled image features).
    """

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1,
                 ln_epsilon=1e-6, is_multimodal=False, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.self_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim,
            dropout=self.dropout_rate
        )
        if is_multimodal:
            # Cross-attention sub-layer only exists in multimodal blocks.
            self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
            self.dropout2 = layers.Dropout(self.dropout_rate)
            self.cross_attention = layers.MultiHeadAttention(
                num_heads=self.num_heads,
                key_dim=self.embed_dim,
                dropout=self.dropout_rate
            )
        self.dense_proj = tf.keras.Sequential([
            layers.Dense(self.dense_dim, activation="gelu"),
            layers.Dropout(self.dropout_rate),
            layers.Dense(self.embed_dim)
        ])
        self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        self.dropout1 = layers.Dropout(self.dropout_rate)
        self.dropout3 = layers.Dropout(self.dropout_rate)

    def get_causal_attention_mask(self, inputs):
        """Lower-triangular (causal) boolean mask, shape (1, seq, seq)."""
        seq_len = tf.shape(inputs)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0)
        return tf.expand_dims(causal_mask, 0)

    def get_combined_mask(self, causal_mask, padding_mask):
        """AND the causal mask with a (batch, seq) padding mask over the key axis."""
        padding_mask = tf.cast(padding_mask, tf.bool)
        padding_mask = tf.expand_dims(padding_mask, 1)  # broadcast over query axis
        return causal_mask & padding_mask

    def call(self, inputs, encoder_outputs=None, mask=None):
        att_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            att_mask = self.get_combined_mask(att_mask, mask)
        # Pre-norm residual self-attention.
        x = self.norm1(inputs)
        attention_output_1 = self.self_attention(
            query=x, key=x, value=x, attention_mask=att_mask
        )
        attention_output_1 = self.dropout1(attention_output_1)
        x = x + attention_output_1
        if encoder_outputs is not None:
            # Cross-attend to the image features (multimodal blocks only).
            x_norm = self.norm2(x)
            attention_output_2 = self.cross_attention(
                query=x_norm, key=encoder_outputs, value=encoder_outputs
            )
            attention_output_2 = self.dropout2(attention_output_2)
            x = x + attention_output_2
        # Feed-forward sub-layer.
        x_norm = self.norm3(x)
        proj_output = self.dense_proj(x_norm)
        proj_output = self.dropout3(proj_output)
        return x + proj_output


class UnimodalTextDecoder(layers.Layer):
    """Stack of text-only (causal self-attention) transformer blocks."""

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1,
                 ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers
        # Attribute name `layers` is kept as-is so checkpoint weight paths match.
        self.layers = [
            TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads,
                             self.dropout_rate, self.ln_epsilon, is_multimodal=False)
            for _ in range(self.num_layers)
        ]
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, mask=None):
        for layer in self.layers:
            x = layer(inputs=x, mask=mask)
        return self.norm(x)


class MultimodalTextDecoder(layers.Layer):
    """Stack of transformer blocks that cross-attend to image features."""

    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1,
                 ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers
        self.layers = [
            TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads,
                             self.dropout_rate, self.ln_epsilon, is_multimodal=True)
            for _ in range(self.num_layers)
        ]
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, encoder_outputs, mask=None):
        for layer in self.layers:
            x = layer(inputs=x, encoder_outputs=encoder_outputs, mask=mask)
        return self.norm(x)


class EmbedToLatents(layers.Layer):
    """Project embeddings to the contrastive latent space and L2-normalize."""

    def __init__(self, dim_latents, **kwargs):
        super(EmbedToLatents, self).__init__(**kwargs)
        self.dim_latents = dim_latents
        self.to_latents = layers.Dense(self.dim_latents, use_bias=False)

    def call(self, inputs):
        latents = self.to_latents(inputs)
        return tf.math.l2_normalize(latents, axis=-1)


class Perplexity(tf.keras.metrics.Metric):
    """Token-level perplexity: exp(sum of masked CE loss / number of non-pad tokens)."""

    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        self.total_loss = self.add_weight(name='total_loss', initializer='zeros')
        self.total_tokens = self.add_weight(name='total_tokens', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')
        loss = loss_fn(y_true, y_pred)
        # Ignore padding (token id 0) when accumulating.
        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
        loss = tf.reduce_sum(loss * mask)
        num_tokens = tf.reduce_sum(mask)
        self.total_loss.assign_add(loss)
        self.total_tokens.assign_add(num_tokens)

    def result(self):
        return tf.exp(self.total_loss / self.total_tokens)

    def reset_state(self):
        # Renamed from the legacy `reset_states` so the override is actually
        # invoked by Keras 3 and by CoCaModel.reset_metrics() below.
        self.total_loss.assign(0.0)
        self.total_tokens.assign(0.0)


# ---------------------------------------------------------------------------
# Pretrained ViT-tiny backbone for the CoCa encoder (fully fine-tunable).
# ---------------------------------------------------------------------------
model_name = "WinKawaks/vit-tiny-patch16-224"
vit_tiny_model = TFAutoModel.from_pretrained(model_name)
vit_tiny_model.trainable = True
for layer in vit_tiny_model.layers:
    layer.trainable = True


class CoCaEncoder(tf.keras.Model):
    """ViT backbone + two attentional poolers.

    Produces a single contrastive embedding (1 query) and a longer sequence of
    caption features (``cap_query_num`` queries) for cross-attention.
    """

    def __init__(self, vit, **kwargs):
        super().__init__(**kwargs)
        self.vit = vit
        self.contrastive_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
        self.caption_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
        self.con_query = tf.Variable(
            initial_value=tf.random.normal([1, 1, proj_dim]),
            trainable=True, name="con_query"
        )
        self.cap_query = tf.Variable(
            initial_value=tf.random.normal([1, cap_query_num, proj_dim]),
            trainable=True, name="cap_query"
        )

    def call(self, input, training=False):
        img_feature = self.vit(input).last_hidden_state
        batch_size = tf.shape(img_feature)[0]
        # Tile the learned queries across the batch.
        con_query_b = tf.repeat(self.con_query, repeats=batch_size, axis=0)
        cap_query_b = tf.repeat(self.cap_query, repeats=batch_size, axis=0)
        con_feature = self.contrastive_pooling(img_feature, con_query_b)
        cap_feature = self.caption_pooling(img_feature, cap_query_b)
        return con_feature, cap_feature


class CoCaDecoder(tf.keras.Model):
    """Unimodal + multimodal text decoder with an appended [CLS] slot.

    Returns the normalized [CLS] feature (for the contrastive loss) and the
    per-position vocabulary logits (for the captioning loss).
    """

    def __init__(self, cls_token_id, num_heads, num_layers, **kwargs):
        super().__init__(**kwargs)
        self.cls_token_id = cls_token_id
        self.pos_emb = PositionalEmbedding(sentence_length, vocab_size, proj_dim)
        self.unimodal_decoder = UnimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.multimodal_decoder = MultimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.to_logits = tf.keras.layers.Dense(vocab_size, name='logits_projection')
        self.norm = layers.LayerNormalization()

    def call(self, inputs, training=False):
        input_text, cap_feature = inputs
        batch_size = tf.shape(input_text)[0]
        # Append a [CLS] token at the end of every caption.
        cls_tokens = tf.fill(
            [batch_size, 1], tf.cast(self.cls_token_id, input_text.dtype))
        ids = tf.concat([input_text, cls_tokens], axis=1)
        text_mask = tf.not_equal(input_text, 0)
        # [CLS] position is masked out as a key (zeros == False) so text tokens
        # never attend to it; it still attends to them as a query.
        cls_mask = tf.zeros([batch_size, 1], dtype=text_mask.dtype)
        extended_mask = tf.concat([text_mask, cls_mask], axis=1)
        txt_embs = self.pos_emb(ids)
        unimodal_out = self.unimodal_decoder(txt_embs, mask=extended_mask)
        # Caption logits come from the text positions only (drop the [CLS] slot).
        multimodal_out = self.multimodal_decoder(
            unimodal_out[:, :-1, :], cap_feature, mask=text_mask)
        cls_token_feature = self.norm(unimodal_out[:, -1:, :])
        multimodal_logits = self.to_logits(multimodal_out)
        return cls_token_feature, multimodal_logits


class CoCaModel(tf.keras.Model):
    """Full CoCa model: captioning loss + image/text contrastive loss."""

    def __init__(self, vit, cls_token_id, num_heads, num_layers):
        super().__init__()
        self.encoder = CoCaEncoder(vit, name="coca_encoder")
        self.decoder = CoCaDecoder(cls_token_id, num_heads, num_layers,
                                   name="coca_decoder")
        self.img_to_latents = EmbedToLatents(proj_dim)
        self.text_to_latents = EmbedToLatents(proj_dim)
        self.pad_id = 0
        self.temperature = 0.07
        self.caption_loss_weight = 1.0
        self.contrastive_loss_weight = 1.0
        self.perplexity = Perplexity()

    def call(self, inputs, training=False):
        image, text = inputs
        con_feature, cap_feature = self.encoder(image)
        cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
        return con_feature, cls_token_feature, multimodal_logits

    def compile(self, optimizer):
        super().compile()
        self.optimizer = optimizer

    def compute_caption_loss(self, multimodal_out, caption_target):
        """Mean cross-entropy over non-pad target tokens."""
        caption_loss = tf.keras.losses.sparse_categorical_crossentropy(
            caption_target, multimodal_out, from_logits=True,
            ignore_class=self.pad_id)
        return tf.reduce_mean(caption_loss)

    def compute_contrastive_loss(self, con_feature, cls_feature):
        """Symmetric InfoNCE between image and text latents (in-batch negatives)."""
        text_embeds = tf.squeeze(cls_feature, axis=1)
        image_embeds = tf.squeeze(con_feature, axis=1)
        text_latents = self.text_to_latents(text_embeds)
        image_latents = self.img_to_latents(image_embeds)
        sim = tf.matmul(text_latents, image_latents,
                        transpose_b=True) / self.temperature
        batch_size = tf.shape(sim)[0]
        contrastive_labels = tf.range(batch_size)
        loss1 = tf.keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, sim, from_logits=True)
        loss2 = tf.keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(sim), from_logits=True)
        contrastive_loss = tf.reduce_mean((loss1 + loss2) * 0.5)
        return contrastive_loss

    def train_step(self, data):
        (images, caption_input), caption_target = data
        with tf.GradientTape() as tape:
            con_feature, cls_feature, multimodal_out = self(
                [images, caption_input], training=True)
            caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
            contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
            total_loss = (self.caption_loss_weight * caption_loss
                          + self.contrastive_loss_weight * contrastive_loss)
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.perplexity.update_state(caption_target, multimodal_out)
        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result()
        }

    def test_step(self, data):
        (images, caption_input), caption_target = data
        con_feature, cls_feature, multimodal_out = self(
            [images, caption_input], training=False)
        caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
        contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
        total_loss = (self.caption_loss_weight * caption_loss
                      + self.contrastive_loss_weight * contrastive_loss)
        self.perplexity.update_state(caption_target, multimodal_out)
        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result()
        }

    def reset_metrics(self):
        self.perplexity.reset_state()


# ---------------------------------------------------------------------------
# Build the CoCa model, trace it once with dummy inputs, and load the checkpoint.
# ---------------------------------------------------------------------------
coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id,
                       num_heads=num_heads, num_layers=num_layers)
# ViT expects channels-first input: (batch, 3, H, W).
dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)
dummy_captions = tf.zeros((1, sentence_length - 1), dtype=tf.int64)
_ = coca_model((dummy_features, dummy_captions))

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
coca_model.compile(optimizer)

save_dir = "models/"
model_name = "coca"
coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")

# ---------------------------------------------------------------------------
# ResNet50 region-feature extractor for the RNN captioner (7x7x2048 grid).
# ---------------------------------------------------------------------------
img_embed_dim = 2048
reg_count = 7 * 7
base_model = ResNet50(weights='imagenet', include_top=False)
model = Model(inputs=base_model.input, outputs=base_model.output)


def preprocess_image(img):
    """Resize to (img_size, img_size), apply ResNet preprocessing, add batch dim."""
    img = tf.image.resize(img, (img_size, img_size))
    img = tf.convert_to_tensor(img)
    img = preprocess_input(img)
    return np.expand_dims(img, axis=0)


def create_features(img):
    """Extract (1, 49, 2048) ResNet50 region features for one image."""
    img = preprocess_image(img)
    features = model.predict(img, verbose=0)
    features = features.reshape((1, reg_count, img_embed_dim))
    return features


class BahdanauAttention(layers.Layer):
    """Additive (Bahdanau) attention over image region features."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.W1 = layers.Dense(units)
        self.W2 = layers.Dense(units)
        self.V = layers.Dense(1)

    def call(self, features, hidden):
        # score_i = V . tanh(W1 f_i + W2 h); alpha = softmax over regions.
        hidden = tf.expand_dims(hidden, 1)
        score = self.V(tf.nn.tanh(self.W1(features) + self.W2(hidden)))
        alpha = tf.nn.softmax(score, axis=1)
        context = tf.reduce_sum(alpha * features, axis=1)
        return context, alpha


class ImageCaptioningModel(tf.keras.Model):
    """Show-attend-and-tell style LSTM captioner over ResNet region features."""

    def __init__(self, vocab_size, max_caption_len, embedding_dim=512,
                 lstm_units=512, dropout_rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_caption_len = max_caption_len
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.dropout_rate = dropout_rate
        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        self.embedding_dropout = layers.Dropout(dropout_rate)
        self.lstm = layers.LSTM(lstm_units, return_sequences=True,
                                return_state=True)
        self.attention = BahdanauAttention(lstm_units)
        self.fc_dropout = layers.Dropout(dropout_rate)
        self.fc = layers.Dense(vocab_size, activation='softmax')
        # Initial LSTM state is derived from the mean image feature.
        self.init_h = layers.Dense(lstm_units, activation='tanh')
        self.init_c = layers.Dense(lstm_units)
        self.concatenate = layers.Concatenate(axis=-1)

    def call(self, inputs):
        features, captions = inputs
        mean_features = tf.reduce_mean(features, axis=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)
        embeddings = self.embedding(captions)
        embeddings = self.embedding_dropout(embeddings)
        outputs = []
        # Teacher-forced unroll: one LSTM step per caption position.
        for t in range(self.max_caption_len):
            context, _ = self.attention(features, h)
            lstm_input = self.concatenate([embeddings[:, t, :], context])
            lstm_input = tf.expand_dims(lstm_input, 1)
            output, h, c = self.lstm(lstm_input, initial_state=[h, c])
            outputs.append(output)
        outputs = tf.concat(outputs, axis=1)
        outputs = self.fc_dropout(outputs)
        return self.fc(outputs)  # per-position softmax over the vocabulary


# ---------------------------------------------------------------------------
# Build the RNN model, trace it once with dummy inputs, and load the checkpoint.
# ---------------------------------------------------------------------------
rnn_model = ImageCaptioningModel(vocab_size, sentence_length - 1,
                                 rnn_embedding_dim, rnn_proj_dim)
image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
_ = rnn_model([image_input, text_input])
rnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[Perplexity()]
)

save_dir = "models/"
model_name = "rnn_attn"
rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")

# ---------------------------------------------------------------------------
# Beam-search decoding parameters shared by both generators.
# ---------------------------------------------------------------------------
beam_width = 3
max_length = sentence_length - 1
temperature = 1.0

# ViT normalization constants (match the WinKawaks/vit-tiny preprocessing).
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]


def load_and_preprocess_image(img):
    """Normalize an RGB image for the ViT: scale, standardize, HWC -> CHW, batch."""
    img = tf.convert_to_tensor(img)
    img = tf.image.resize(img, (img_size, img_size))
    img = img / 255.0
    img = (img - image_mean) / image_std
    img = tf.transpose(img, perm=[2, 0, 1])  # ViT expects channels-first
    return np.expand_dims(img, axis=0)


def has_repeated_ngrams(seq, n=2):
    """Return True if any n-gram occurs more than once in `seq`."""
    ngrams = [tuple(seq[i:i + n]) for i in range(len(seq) - n + 1)]
    return len(ngrams) != len(set(ngrams))


def generate_caption_coca(image):
    """Beam-search caption generation with the CoCa model."""
    img_processed = load_and_preprocess_image(image)
    _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
    # Each beam is (token id sequence, length-normalized log-probability).
    beams = [([word_index[start_token]], 0.0)]
    for _ in range(max_length):
        new_beams = []
        for seq, log_prob in beams:
            if seq[-1] == word_index[end_token]:
                new_beams.append((seq, log_prob))  # finished beam: keep as-is
                continue
            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :len(seq)] = seq
            predictions = coca_model.decoder.predict(
                [text_input, cap_features], verbose=0)
            _, logits = predictions
            logits = logits[0, len(seq) - 1, :]
            # Numerically stable softmax over the logits.
            probs = np.exp(logits - np.max(logits))
            probs /= probs.sum()
            top_k = np.argsort(-probs)[:beam_width]
            for token in top_k:
                new_seq = seq + [token]
                # Running length-normalized score.
                new_log_prob = ((log_prob * len(seq) + np.log(probs[token]))
                                / (len(seq) + 1))
                if has_repeated_ngrams(new_seq, n=2):
                    new_log_prob -= 0.5  # penalize bigram repetition
                new_beams.append((new_seq, new_log_prob))
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(beam[0][-1] == word_index[end_token] for beam in beams):
            break
    best_seq = max(beams, key=lambda x: x[1])[0]
    return " ".join(index_word[i] for i in best_seq
                    if i not in {word_index[start_token], word_index[end_token]})


def generate_caption_rnn(image):
    """Beam-search caption generation with the attention-RNN model."""
    image_embedding = create_features(image)
    beams = [([word_index[start_token]], 0.0)]
    for _ in range(max_length):
        new_beams = []
        for seq, log_prob in beams:
            if seq[-1] == word_index[end_token]:
                new_beams.append((seq, log_prob))
                continue
            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :len(seq)] = seq
            predictions = rnn_model.predict(
                [image_embedding, text_input], verbose=0)
            probs = predictions[0, len(seq) - 1, :]
            # Temperature sharpening of the already-softmaxed distribution.
            probs = probs ** (1 / temperature)
            probs /= probs.sum()
            top_k = np.argpartition(probs, -beam_width)[-beam_width:]
            for token in top_k:
                new_seq = seq + [token]
                new_log_prob = ((log_prob * len(seq) + np.log(probs[token]))
                                / (len(seq) + 1))
                if has_repeated_ngrams(new_seq, n=2):
                    new_log_prob -= 0.5
                new_beams.append((new_seq, new_log_prob))
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(beam[0][-1] == word_index[end_token] for beam in beams):
            break
    best_seq = max(beams, key=lambda x: x[1])[0]
    return " ".join(index_word[i] for i in best_seq
                    if i not in {word_index[start_token], word_index[end_token]})


def generate_both(image):
    """Gradio callback: caption the image with both models."""
    caption1 = generate_caption_rnn(image)
    caption2 = generate_caption_coca(image)
    return f"RNN: {caption1}\n\nCoCa: {caption2}"


interface = gr.Interface(
    fn=generate_both,
    inputs=gr.Image(type="pil", label="Изображение"),
    outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
    allow_flagging="never",
    submit_btn="Сгенерировать",
    clear_btn="Очистить",
    deep_link=False
)

with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ Генератор описаний к изображениям")
    interface.render()

if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_api=False)