from ..imports import *


class Mask(keras.layers.Layer):
    """Additive causal mask.

    Produces a (length, length) float tensor that is 0 where position i may
    attend to position j (j <= i) and -1e9 elsewhere, intended to be added to
    attention logits before the softmax.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs, **kwargs):
        # inputs: (..., length, dim) -- only the sequence length is used.
        length = tf.shape(inputs)[-2]
        # Row i is True for columns 0..i (lower-triangular allow-pattern).
        mask = tf.sequence_mask(tf.range(1, length + 1), length)
        return tf.where(mask, 0., -1e9)


class SeqEmbedding(keras.layers.Layer):
    """Token embedding plus learned absolute positional embedding."""

    def __init__(self):
        super().__init__()
        self.token_emb = keras.layers.Embedding(VOCAB_SIZE, TEXT_EMBEDDING_DIM)
        self.position = keras.layers.Embedding(INPUT_SEQ_LENGTH, TEXT_EMBEDDING_DIM)

    def call(self, txt):
        # NOTE(review): the positional term is built for exactly
        # INPUT_SEQ_LENGTH positions, so txt's last axis must be that size.
        return self.token_emb(txt) + self.position(tf.range(INPUT_SEQ_LENGTH))[tf.newaxis]


class Attention(keras.layers.Layer):
    """Single-head scaled dot-product attention with learned Q/K/V projections."""

    def __init__(self):
        super().__init__()
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, mask=None, **kwargs):
        # mask, if given, is an additive logit mask (e.g. produced by Mask).
        Q = self.Q(q)
        K = self.K(k)
        V = self.V(v)
        # This head runs at the full model width, so scaling by
        # sqrt(TEXT_EMBEDDING_DIM) is correct here.
        attention = tf.matmul(Q, K, transpose_b=True) / (TEXT_EMBEDDING_DIM ** 0.5)
        if mask is not None:
            attention += mask
        attention = tf.nn.softmax(attention, axis=-1)
        return tf.matmul(attention, V)


class SelfAttention(keras.layers.Layer):
    """Causal self-attention with residual connection and layer normalization."""

    def __init__(self):
        super().__init__()
        self.sa = Attention()
        self.norm = keras.layers.LayerNormalization()
        self.mask = Mask()

    def call(self, inputs, **kwargs):
        mask = self.mask(inputs)
        return self.norm(self.sa(inputs, inputs, inputs, mask) + inputs)


class CrossAttention(keras.layers.Layer):
    """Cross-attention (queries from src, keys/values from mem) + residual + norm."""

    def __init__(self):
        super().__init__()
        self.ca = Attention()
        self.norm = keras.layers.LayerNormalization()

    def call(self, src, mem, **kwargs):
        return self.norm(self.ca(src, mem, mem) + src)


class DecoderLayer(keras.layers.Layer):
    """One transformer decoder layer: causal self-attn -> cross-attn -> FFN."""

    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.ca = CrossAttention()
        self.ff = FeedForward()
        # NOTE(review): self.norm and self.mask are never used in call().
        # Kept as-is because self.norm carries trainable variables and removing
        # it would change the checkpoint variable structure.
        self.norm = keras.layers.LayerNormalization()
        self.mask = Mask()

    def call(self, inp):
        seq, enc = inp
        sa = self.sa(seq)
        ca = self.ca(sa, enc)
        return self.ff(ca)


class EncoderLayer(keras.layers.Layer):
    """One transformer encoder layer: unmasked self-attention + residual + norm."""

    def __init__(self):
        super().__init__()
        self.attn = Attention()
        self.norm = keras.layers.LayerNormalization()

    def call(self, inputs):
        out = self.attn(inputs, inputs, inputs) + inputs
        return self.norm(out)


class FeedForward(keras.layers.Layer):
    """Position-wise feed-forward block with residual connection and layer norm."""

    def __init__(self):
        super().__init__()
        self.seq = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM * 2, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM),
            keras.layers.Dropout(0.3),
        ])
        self.norm = keras.layers.LayerNormalization()

    def call(self, inputs, **kwargs):
        return self.norm(self.seq(inputs) + inputs)


class FastCaption(keras.Model):
    """Caption model: frozen VGG16 features -> 1 encoder layer -> 1 decoder layer.

    Token conventions inferred from the generation loop: 1 = BOS, 2 = EOS,
    0 = padding -- TODO confirm against the tokenizer.
    """

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.VGG16(include_top=False)
        self.backbone.trainable = False  # used purely as a feature extractor
        self.decoder = DecoderLayer()
        self.encoder = EncoderLayer()
        self.adapt = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.dense = keras.layers.Dense(VOCAB_SIZE, activation='softmax')
        self.embedding = SeqEmbedding()

    def call(self, inputs):
        img, txt = inputs
        img = keras.applications.vgg16.preprocess_input(img)
        # BUG FIX (comment only): VGG16 include_top=False emits 512 channels,
        # not 2048 as the original comment claimed.
        img = self.backbone(img)  # (batch, h, w, 512)
        img = self.adapt(img)     # (batch, h, w, TEXT_EMBEDDING_DIM)
        # Flatten the spatial grid into a sequence of h*w visual tokens.
        img = tf.reshape(img, [tf.shape(img)[0],
                               tf.shape(img)[1] * tf.shape(img)[2],
                               tf.shape(img)[3]])
        img = self.encoder(img)
        seq = self.embedding(txt)  # (batch, permutations, length, dim)
        out = self.decoder((seq, img[:, tf.newaxis]))
        return self.dense(out)

    @staticmethod
    @tf.function
    def generate_caption(model, img):
        """Greedy autoregressive caption generation for one preprocessed image."""
        img = keras.applications.vgg16.preprocess_input(img)
        img = model.backbone(img)
        img = model.adapt(img)
        img = tf.reshape(img, [tf.shape(img)[0],
                               tf.shape(img)[1] * tf.shape(img)[2],
                               tf.shape(img)[3]])
        img = model.encoder(img)
        # Seed: all padding; _generate_from_seed injects the BOS token itself.
        txt = tf.zeros((1, 1, INPUT_SEQ_LENGTH), dtype=tf.int32)
        return FastCaption._generate_from_seed(model, img, txt,
                                               tf.constant(0, dtype=tf.int32))

    @staticmethod
    @tf.function
    def _generate_from_seed(model, img, txt, index):
        """Greedy decode loop: accept one additional position per iteration.

        Stops as soon as an EOS (2) appears anywhere, or no padding (0) is left.
        NOTE(review): img is passed rank-3 here while call() feeds the decoder
        img[:, tf.newaxis]; it only works via tf.matmul batch broadcasting --
        confirm this asymmetry is intentional.
        """
        while tf.math.logical_not(
                tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)),
                                   tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Keep only predictions up to the current position; later slots
            # remain padding (0) until their turn.
            valid = tf.cast(tf.range(tf.shape(txt)[2]) <= index, dtype=tf.int32)
            new_text = new_text * valid
            # Re-seed with BOS (1) and shift the accepted predictions right.
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32),
                             new_text[:, :, :-1]], axis=2)
            index = index + 1
        # Drop the BOS and pad back to the fixed sequence length.
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)],
                         axis=2)


class MemorizedAttention(keras.layers.Layer):
    """Attention head with learned "memory" key/value slots appended to K and V
    (memory-augmented attention in the style of the Meshed-Memory Transformer).
    """

    def __init__(self, dim):
        super().__init__()
        self.Q = keras.layers.Dense(dim)
        self.K = keras.layers.Dense(dim)
        self.V = keras.layers.Dense(dim)
        self.dim = dim
        self.mask = Mask()

    def build(self, input_shape):
        # EXPANSION_LENGTH learned memory slots, tiled over batch dims in call().
        self.memory_k = self.add_weight(
            shape=(1, 1, EXPANSION_LENGTH, self.dim),
            initializer="glorot_uniform", trainable=True, name="memory_k")
        self.memory_v = self.add_weight(
            shape=(1, 1, EXPANSION_LENGTH, self.dim),
            initializer="glorot_uniform", trainable=True, name="memory_v")
        # FIX: mark the layer as built via the base implementation.
        super().build(input_shape)

    def call(self, q, k, v, use_causal=False):
        # Tile the (1, 1, E, dim) memories across the leading batch dims of q.
        repeater = tf.concat([tf.shape(q)[:-2],
                              tf.constant([1, 1], dtype=tf.int32)], axis=-1)
        memory_k = tf.tile(self.memory_k, repeater)
        memory_v = tf.tile(self.memory_v, repeater)
        Q = self.Q(q)
        K = tf.concat([self.K(k), memory_k], axis=-2)
        V = tf.concat([self.V(v), memory_v], axis=-2)
        # BUG FIX: scale by this head's dimension (self.dim), not the full
        # model width -- consistent with NormalAttention below. Heads are
        # constructed with dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD.
        attn = tf.matmul(Q, K, transpose_b=True) / math.sqrt(self.dim)
        if use_causal:
            # NOTE(review): Mask yields a (Lq, Lq) mask but attn's last axis is
            # Lk + EXPANSION_LENGTH, so this path cannot broadcast whenever
            # memory slots exist; no current call site passes use_causal=True.
            mask = self.mask(attn)
            attn += mask
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)


class MultiheadMemorizedAttention(keras.layers.Layer):
    """Multi-head wrapper over MemorizedAttention with output projection,
    residual connection and layer normalization."""

    def __init__(self):
        super().__init__()
        self.heads = [MemorizedAttention(TEXT_EMBEDDING_DIM // ATTENTION_HEAD)
                      for _ in range(ATTENTION_HEAD)]
        self.norm = keras.layers.LayerNormalization()
        self.dense = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, use_causal=False):
        # NOTE: every head sees the full-width q/k/v; splitting happens inside
        # each head's Dense projections, not by slicing the inputs.
        tensors = [head(q, k, v, use_causal=use_causal) for head in self.heads]
        return self.norm(self.dense(tf.concat(tensors, axis=-1)) + q)


class RelativePositionalSelfAttention(keras.layers.Layer):
    """Self-attention head with learned relative positional key embeddings
    (Shaw et al.-style additive Q.E^T term)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # NOTE(review): hard-coded maximum length (81 = 9x9 feature grid?);
        # sequences longer than this overflow the embedding table -- confirm.
        self.seq_length = 81
        self.positional_k = keras.layers.Embedding(self.seq_length * 2 + 1,
                                                   self.dim)

    def embedding_net(self, Q):
        # Relative-distance index matrix: entry (i, j) = (L - 1) + (i - j),
        # i.e. distances shifted to be non-negative in [0, 2L - 2].
        seq_length = tf.shape(Q)[-2]
        return (tf.range(seq_length - 1, seq_length * 2 - 1)[..., tf.newaxis]
                - tf.range(seq_length))

    def mask(self, attn):
        # Lower-triangular additive causal mask over the (Lq, Lq) logits.
        mask = (tf.range(tf.shape(attn)[-2])[..., tf.newaxis]
                >= tf.range(tf.shape(attn)[-2]))
        return tf.where(mask, 0., -1e9)

    def call(self, Q, K, V, use_causal=False):
        emb = self.positional_k(self.embedding_net(Q))
        # Content term + relative-position term.
        eij = (tf.matmul(Q, K, transpose_b=True)
               + tf.einsum('...id, ijd -> ...ij', Q, emb))
        # BUG FIX: scale by this head's dimension (self.dim), not the full
        # model width -- consistent with NormalAttention.
        attn = eij / math.sqrt(self.dim)
        if use_causal:
            mask = self.mask(attn)
            attn += mask
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)


class MultiheadRelativePositionalSelfAttention(keras.layers.Layer):
    """Multi-head relative-position attention: project once, split per head.

    NOTE(review): unlike MultiheadAttention/MultiheadMemorizedAttention this
    wrapper applies no output projection, residual or norm -- confirm intended.
    """

    def __init__(self):
        super().__init__()
        self.heads = [RelativePositionalSelfAttention(
            TEXT_EMBEDDING_DIM // ATTENTION_HEAD) for _ in range(ATTENTION_HEAD)]
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, use_causal=False):
        Q = self.Q(q)
        K = self.K(k)
        V = self.V(v)
        splitted_Q = tf.split(Q, num_or_size_splits=ATTENTION_HEAD, axis=-1)
        splitted_K = tf.split(K, num_or_size_splits=ATTENTION_HEAD, axis=-1)
        splitted_V = tf.split(V, num_or_size_splits=ATTENTION_HEAD, axis=-1)
        tensors = [head(q, k, v, use_causal=use_causal)
                   for head, q, k, v in zip(self.heads, splitted_Q,
                                            splitted_K, splitted_V)]
        return tf.concat(tensors, axis=-1)


class NormalAttention(keras.layers.Layer):
    """Plain scaled dot-product attention over pre-projected Q/K/V slices."""

    def __init__(self, dim):
        # BUG FIX: super().__init__() must run before any attribute is set;
        # tf.keras layer attribute tracking is only initialized by the base
        # constructor and assigning beforehand raises at construction time.
        super().__init__()
        self.dim = dim

    def mask(self, attn):
        # Lower-triangular additive causal mask over the (Lq, Lq) logits.
        mask = (tf.range(tf.shape(attn)[-2])[..., tf.newaxis]
                >= tf.range(tf.shape(attn)[-2]))
        return tf.where(mask, 0., -1e9)

    def call(self, Q, K, V, use_causal=False):
        eij = tf.matmul(Q, K, transpose_b=True)
        attn = eij / math.sqrt(self.dim)  # per-head scaling
        if use_causal:
            mask = self.mask(attn)
            attn += mask
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)


class MultiheadAttention(keras.layers.Layer):
    """Standard multi-head attention: project, split per head, concat, norm."""

    def __init__(self):
        super().__init__()
        self.heads = [NormalAttention(TEXT_EMBEDDING_DIM // ATTENTION_HEAD)
                      for _ in range(ATTENTION_HEAD)]
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.norm = keras.layers.LayerNormalization()

    def call(self, q, k, v, use_causal=False):
        Q = self.Q(q)
        K = self.K(k)
        V = self.V(v)
        splitted_Q = tf.split(Q, num_or_size_splits=ATTENTION_HEAD, axis=-1)
        splitted_K = tf.split(K, num_or_size_splits=ATTENTION_HEAD, axis=-1)
        splitted_V = tf.split(V, num_or_size_splits=ATTENTION_HEAD, axis=-1)
        tensors = [head(q, k, v, use_causal=use_causal)
                   for head, q, k, v in zip(self.heads, splitted_Q,
                                            splitted_K, splitted_V)]
        # NOTE(review): no residual here, unlike MultiheadMemorizedAttention.
        return self.norm(tf.concat(tensors, axis=-1))


class MeshedEncoder(keras.layers.Layer):
    """Encoder block whose attention flavor is chosen at import time by CHOICE."""

    def __init__(self):
        super().__init__()
        self.m_attn = {
            0: MultiheadMemorizedAttention,
            1: MultiheadRelativePositionalSelfAttention,
            2: MultiheadAttention,
        }[CHOICE]()
        self.f = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM),
        ])
        self.norm = keras.layers.LayerNormalization()

    def call(self, inp):
        # NOTE(review): the same LayerNormalization instance (and weights) is
        # applied at both residual sites -- confirm this sharing is intended.
        z = self.norm(self.m_attn(inp, inp, inp) + inp)
        x = self.norm(self.f(z) + z)
        return x


class MeshedDecoder(keras.layers.Layer):
    """Meshed decoder block: attends to every encoder level and fuses them
    with a learned sigmoid gate (Meshed-Memory Transformer style)."""

    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.ca = Attention()
        # Gate producing per-feature weights in (0, 1) for each encoder level.
        self.dense = keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='sigmoid')
        self.norm = keras.layers.LayerNormalization()
        self.f = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM),
        ])

    def call(self, inp):
        src, tgts = inp  # src: decoder stream; tgts: list of encoder outputs
        # NOTE(review): one shared LayerNormalization is reused at every norm
        # site below, and the output is normalized twice -- confirm intended.
        sa = self.norm(self.sa(src))
        gated = tf.zeros(tf.shape(sa), dtype=tf.float32)
        for tgt in tgts:
            c = self.norm(self.ca(sa, tgt, tgt) + sa)
            alpha = self.dense(tf.concat([sa, c], axis=-1))
            gated += alpha * c  # gated sum over encoder levels
        f = self.norm(self.f(gated) + gated)
        return self.norm(f)


class MultiLayerMeshed(keras.layers.Layer):
    """Stack of MESHED_DEPTH meshed encoder/decoder blocks.

    Every decoder block receives ALL encoder levels (including the raw image
    features) -- the "meshed" connectivity pattern.
    """

    def __init__(self):
        super().__init__()
        self.enc = [MeshedEncoder() for _ in range(MESHED_DEPTH)]
        self.dec = [MeshedDecoder() for _ in range(MESHED_DEPTH)]

    def call(self, inp):
        src, tgt = inp  # src: text stream; tgt: image feature tokens
        srclst = [tgt]
        for block in self.enc:
            srclst.append(block(srclst[-1]))
        out = src
        for dec in self.dec:
            out = dec((out, srclst))
        return out


class EfficientNetVision(keras.layers.Layer):
    """Frozen EfficientNetB2 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.EfficientNetB2(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        processed = keras.applications.efficientnet.preprocess_input(image)
        return self.backbone(processed)


class ResnetVision(keras.layers.Layer):
    """Frozen ResNet50V2 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.ResNet50V2(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        processed = keras.applications.resnet_v2.preprocess_input(image)
        return self.backbone(processed)


class VGGVision(keras.layers.Layer):
    """Frozen VGG16 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.VGG16(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        processed = keras.applications.vgg16.preprocess_input(image)
        return self.backbone(processed)


class Resnet152Vision(keras.layers.Layer):
    """Frozen ResNet152V2 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.ResNet152V2(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        processed = keras.applications.resnet_v2.preprocess_input(image)
        return self.backbone(processed)


class ShortVision(keras.layers.Layer):
    """Tiny trainable conv stack as a lightweight backbone alternative
    (overall stride 32, 64 output channels)."""

    def __init__(self):
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 8, strides=(8, 8))
        self.conv2 = keras.layers.Conv2D(64, 4, strides=(4, 4))

    def call(self, image, **kwargs):
        # NOTE(review): uses ResNetV2 preprocessing although the layers are
        # plain convs -- presumably just for [-1, 1] scaling; confirm.
        image = keras.applications.resnet_v2.preprocess_input(image)
        return self.conv2(self.conv1(image))
class MeshedFastCaption(keras.Model):
    """Captioning model: pluggable CNN backbone + multi-layer meshed decoder.

    Token conventions inferred from the generation loops: 1 = BOS, 2 = EOS,
    0 = padding -- TODO confirm against the tokenizer.
    """

    def __init__(self):
        super().__init__()
        # Backbone is selected once at construction time via BACKBONE_CHOICE.
        self.vision = {
            0: Resnet152Vision,
            1: ResnetVision,
            2: VGGVision,
            3: EfficientNetVision,
            4: ShortVision,
        }[BACKBONE_CHOICE]()
        self.decoder = MultiLayerMeshed()
        self.adapt = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.dense = keras.layers.Dense(VOCAB_SIZE, activation='softmax')
        self.embedding = SeqEmbedding()

    def _encode_image(self, img):
        """Backbone features -> (batch, h*w, TEXT_EMBEDDING_DIM) visual tokens.

        Factored out because call() and generate_caption() previously
        duplicated this pipeline line-for-line.
        """
        img = self.vision(img)  # backbone applies its own preprocessing
        img = self.adapt(img)   # (batch, h, w, TEXT_EMBEDDING_DIM)
        shape = tf.shape(img)
        # Flatten the spatial grid into a sequence of h*w tokens.
        return tf.reshape(img, [shape[0], shape[1] * shape[2], shape[3]])

    def call(self, inputs):
        img, txt = inputs
        img = self._encode_image(img)
        seq = self.embedding(txt)  # (batch, permutations, length, dim)
        out = self.decoder((seq, img[:, tf.newaxis]))
        return self.dense(out)

    @staticmethod
    @tf.function
    def generate_caption(model, img):
        """Generate a caption for one image via the single-shot decoder."""
        img = model._encode_image(img)
        # Seed: all padding; the seed loop injects the BOS token itself.
        txt = tf.zeros((1, 1, INPUT_SEQ_LENGTH), dtype=tf.int32)
        return MeshedFastCaption._fast_generate_from_seed(
            model, img, txt, tf.constant(0, dtype=tf.int32))

    @staticmethod
    @tf.function
    def _generate_from_seed(model, img, txt, index):
        """Greedy decode loop: accept one additional position per iteration.

        NOTE(review): dead code -- generate_caption() calls
        _fast_generate_from_seed() instead; kept for API compatibility.
        Stops when an EOS (2) appears anywhere or no padding (0) is left.
        """
        while tf.math.logical_not(
                tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)),
                                   tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Keep only predictions up to the current position; later slots
            # remain padding (0) until their turn.
            valid = tf.cast(tf.range(tf.shape(txt)[2]) <= index, dtype=tf.int32)
            new_text = new_text * valid
            # Re-seed with BOS (1) and shift the accepted predictions right.
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32),
                             new_text[:, :, :-1]], axis=2)
            index = index + 1
        # Drop the BOS and pad back to the fixed sequence length.
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)],
                         axis=2)

    @staticmethod
    @tf.function
    def _fast_generate_from_seed(model, img, txt, index):
        """Single-shot ("fast") generation: predict every position at once.

        NOTE(review): each iteration grows txt by one token (BOS + the full
        prediction), while SeqEmbedding assumes a fixed INPUT_SEQ_LENGTH, so
        the shapes only line up on the first pass; the loop appears to rely on
        exiting after a single iteration -- confirm.
        """
        while tf.math.logical_not(
                tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)),
                                   tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Prepend BOS (1); unlike _generate_from_seed, every position is
            # accepted in one shot rather than one per iteration.
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text],
                            axis=2)
            index = index + 1
        return txt[..., 1:]  # strip the BOS