| | from ..imports import *
|
| |
|
class Mask(keras.layers.Layer):
    """Builds an additive causal attention mask.

    Given a tensor whose second-to-last axis is the sequence length, returns
    a (length, length) float matrix holding 0. where attention is allowed
    (key position <= query position) and -1e9 where it must be blocked.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs, **kwargs):
        # Sequence length is read from the second-to-last axis of `inputs`.
        size = tf.shape(inputs)[-2]
        # Row i may attend to key positions 0..i (lower-triangular pattern).
        allowed = tf.sequence_mask(tf.range(1, size + 1), size)
        # Additive mask: 0 keeps a logit, -1e9 effectively zeroes it after softmax.
        return tf.where(allowed, 0., -1e9)
|
| |
|
| |
|
class SeqEmbedding(keras.layers.Layer):
    """Token + learned absolute-position embedding for caption sequences.

    Output is ``token_emb(txt) + position_emb`` with the position term
    broadcast across all leading axes of ``txt``.
    """

    def __init__(self):
        super().__init__()
        # Token lookup: id -> TEXT_EMBEDDING_DIM vector.
        self.token_emb = keras.layers.Embedding(
            VOCAB_SIZE,
            TEXT_EMBEDDING_DIM)
        # Learned absolute positions; table sized for the maximum sequence
        # length the model is trained with.
        self.position = keras.layers.Embedding(
            INPUT_SEQ_LENGTH,
            TEXT_EMBEDDING_DIM
        )

    def call(self, txt):
        # Generalized: use the actual sequence length of `txt` rather than the
        # fixed INPUT_SEQ_LENGTH, so shorter sequences no longer trigger a
        # broadcast error. Behavior is unchanged when the last axis of `txt`
        # has exactly INPUT_SEQ_LENGTH entries (the original only supported
        # that case).
        length = tf.shape(txt)[-1]
        return self.token_emb(txt) + self.position(tf.range(length))[tf.newaxis]
|
| |
|
| |
|
class Attention(keras.layers.Layer):
    """Single-head scaled dot-product attention with learned Q/K/V projections."""

    def __init__(self):
        super().__init__()
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, mask=None, **kwargs):
        queries = self.Q(q)
        keys = self.K(k)
        values = self.V(v)
        # Scaled dot-product logits.
        logits = tf.matmul(queries, keys, transpose_b=True) / (TEXT_EMBEDDING_DIM ** 0.5)
        if mask is not None:
            # `mask` is additive: 0 for allowed positions, -1e9 for blocked ones.
            logits = logits + mask
        weights = tf.nn.softmax(logits, axis=-1)
        return tf.matmul(weights, values)
|
| |
|
| |
|
class SelfAttention(keras.layers.Layer):
    """Causally-masked self-attention with a residual connection and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.sa = Attention()
        self.norm = keras.layers.LayerNormalization()
        self.mask = Mask()

    def call(self, inputs, **kwargs):
        # Additive causal mask derived from the sequence length of `inputs`.
        causal = self.mask(inputs)
        attended = self.sa(inputs, inputs, inputs, causal)
        # Post-norm residual: LayerNorm(attention + input).
        return self.norm(attended + inputs)
|
| |
|
| |
|
class CrossAttention(keras.layers.Layer):
    """Unmasked attention from a source sequence over an external memory,
    with a residual connection and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.ca = Attention()
        self.norm = keras.layers.LayerNormalization()

    def call(self, src, mem, **kwargs):
        # Queries come from `src`; keys and values from the memory `mem`.
        attended = self.ca(src, mem, mem)
        return self.norm(attended + src)
|
| |
|
| |
|
class DecoderLayer(keras.layers.Layer):
    """Single transformer decoder layer: masked self-attention, then
    cross-attention over the encoder memory, then a feed-forward block."""

    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.ca = CrossAttention()
        self.ff = FeedForward()
        # NOTE(review): `norm` and `mask` are created but never used in
        # call(); kept so the layer's tracked variables stay identical
        # (checkpoint compatibility).
        self.norm = keras.layers.LayerNormalization()
        self.mask = Mask()

    def call(self, inp):
        # `inp` is a (text_sequence, encoder_memory) pair.
        seq, enc = inp
        hidden = self.sa(seq)
        hidden = self.ca(hidden, enc)
        return self.ff(hidden)
|
| |
|
| |
|
class EncoderLayer(keras.layers.Layer):
    """Single unmasked self-attention encoder block with residual + LayerNorm."""

    def __init__(self):
        super().__init__()
        self.attn = Attention()
        self.norm = keras.layers.LayerNormalization()

    def call(self, inputs):
        residual = self.attn(inputs, inputs, inputs) + inputs
        return self.norm(residual)
|
| |
|
| |
|
class FeedForward(keras.layers.Layer):
    """Position-wise feed-forward block (expand -> project -> dropout)
    with a residual connection and LayerNorm."""

    def __init__(self):
        super().__init__()
        # NOTE(review): Dropout sits after the projection, immediately before
        # the residual add — unusual placement but preserved as written.
        self.seq = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM * 2, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM),
            keras.layers.Dropout(0.3),
        ])
        self.norm = keras.layers.LayerNormalization()

    def call(self, inputs, **kwargs):
        return self.norm(inputs + self.seq(inputs))
|
| |
|
| |
|
class FastCaption(keras.Model):
    """Image-captioning model: frozen VGG16 features + one-layer transformer
    encoder/decoder and a softmax vocabulary head."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.VGG16(include_top=False)
        self.backbone.trainable = False  # fixed feature extractor
        self.decoder = DecoderLayer()
        self.encoder = EncoderLayer()
        # Projects CNN channels down to the text embedding width.
        self.adapt = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.dense = keras.layers.Dense(VOCAB_SIZE, activation='softmax')
        self.embedding = SeqEmbedding()

    def call(self, inputs):
        """Teacher-forced forward pass over an (image, token_ids) pair.

        Returns per-position vocabulary probabilities.
        """
        img, txt = inputs
        img = keras.applications.vgg16.preprocess_input(img)
        img = self.backbone(img)
        img = self.adapt(img)
        # Flatten the spatial grid (H, W) into a single sequence of patches.
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        img = self.encoder(img)
        seq = self.embedding(txt)
        out = self.decoder((seq, img[:, tf.newaxis]))
        return self.dense(out)

    @staticmethod
    @tf.function
    def generate_caption(model, img):
        """Greedy caption generation for a single image (batch of 1)."""
        img = keras.applications.vgg16.preprocess_input(img)
        img = model.backbone(img)
        img = model.adapt(img)
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        img = model.encoder(img)
        # Seed sequence of all-zero (padding) ids; id 1 acts as start, id 2 as end.
        txt = tf.zeros((1, 1, INPUT_SEQ_LENGTH), dtype=tf.int32)
        return FastCaption._generate_from_seed(model, img, txt, tf.constant(0, dtype=tf.int32))

    @staticmethod
    @tf.function
    def _generate_from_seed(model, img, txt, index):
        """Greedy decoding loop: stops when an end token (2) appears or no
        padding zeros remain."""
        while tf.math.logical_not(tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)), tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            # FIX: insert the extra axis so the encoder memory has the same
            # rank here as in call(); previously `img` was passed without
            # `[:, tf.newaxis]`, inconsistent with call() and with
            # MeshedFastCaption._generate_from_seed.
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Keep only predictions up to the current step; later slots stay 0.
            valid = tf.cast(tf.range(tf.shape(txt)[2]) <= index, dtype=tf.int32)
            new_text = new_text * valid
            # Shift right behind a start token (1), dropping the last slot so
            # the sequence length stays constant.
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text[:, :, :-1]], axis=2)
            index = index + 1
        # Strip the start token and pad one zero to restore the full length.
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)], axis=2)
|
| |
|
| |
|
class MemorizedAttention(keras.layers.Layer):
    """Attention head with learned memory slots appended to the keys/values.

    EXPANSION_LENGTH trainable "memory" key/value slots (as in the
    Meshed-Memory Transformer) are concatenated to the projected keys and
    values so every query can also attend to learned priors.
    """

    def __init__(self, dim):
        super().__init__()
        self.Q = keras.layers.Dense(dim)
        self.K = keras.layers.Dense(dim)
        self.V = keras.layers.Dense(dim)
        self.dim = dim
        # Retained for attribute compatibility; the causal mask is now built
        # inline in call() (see FIX note below).
        self.mask = Mask()

    def build(self, input_shape):
        # Learned memory slots, tiled across the batch at call time.
        self.memory_k = self.add_weight(
            shape=(1, 1, EXPANSION_LENGTH, self.dim),
            initializer="glorot_uniform",
            trainable=True,
            name="memory_k"
        )
        self.memory_v = self.add_weight(
            shape=(1, 1, EXPANSION_LENGTH, self.dim),
            initializer="glorot_uniform",
            trainable=True,
            name="memory_v"
        )

    def call(self, q, k, v, use_causal=False):
        # Tile the (1, 1, EXP, dim) memory across the leading batch dims of `q`.
        repeater = tf.concat([tf.shape(q)[:-2], tf.constant([1, 1], dtype=tf.int32)], axis=-1)
        memory_k = tf.tile(self.memory_k, repeater)
        memory_v = tf.tile(self.memory_v, repeater)
        Q = self.Q(q)
        K = tf.concat([self.K(k), memory_k], axis=-2)
        V = tf.concat([self.V(v), memory_v], axis=-2)
        # NOTE(review): scaling uses the full model dim rather than this
        # head's `dim` — preserved as written; confirm whether per-head
        # scaling was intended.
        attn = tf.matmul(Q, K, transpose_b=True) / math.sqrt(TEXT_EMBEDDING_DIM)
        if use_causal:
            # FIX: the old Mask(attn) produced a (q_len, q_len) matrix, which
            # cannot broadcast against logits of width q_len + EXPANSION_LENGTH
            # once memory slots are appended. Build the causal mask over the
            # real key positions only and leave the memory columns unmasked —
            # memory slots are meant to be visible at every step.
            q_len = tf.shape(attn)[-2]
            real_keys = tf.shape(attn)[-1] - EXPANSION_LENGTH
            causal = tf.where(tf.sequence_mask(tf.range(1, q_len + 1), real_keys), 0., -1e9)
            memory_pad = tf.zeros((q_len, EXPANSION_LENGTH))
            attn += tf.concat([causal, memory_pad], axis=-1)
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)
|
| |
|
| |
|
class MultiheadMemorizedAttention(keras.layers.Layer):
    """Multi-head wrapper around MemorizedAttention with residual + LayerNorm.

    Each head receives the full (unsplit) q/k/v and projects them down to
    TEXT_EMBEDDING_DIM // ATTENTION_HEAD internally; head outputs are
    concatenated, projected, added to the query residually and normalized.
    """

    def __init__(self):
        super().__init__()
        head_dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD
        self.heads = [MemorizedAttention(head_dim) for _ in range(ATTENTION_HEAD)]
        self.norm = keras.layers.LayerNormalization()
        self.dense = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, use_causal=False):
        per_head = [head(q, k, v, use_causal=use_causal) for head in self.heads]
        combined = self.dense(tf.concat(per_head, axis=-1))
        return self.norm(combined + q)
|
| |
|
| |
|
class RelativePositionalSelfAttention(keras.layers.Layer):
    """Scaled dot-product attention with learned relative-position key biases.

    A key-side relative embedding table of size 2 * seq_length + 1 supplies a
    bias logit for every (query, key) offset.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # NOTE(review): hard-coded maximum length of 81 (presumably a 9x9
        # feature grid — confirm); longer sequences would index past the table.
        self.seq_length = 81
        self.positional_k = keras.layers.Embedding(
            self.seq_length * 2 + 1,
            self.dim
        )

    def embedding_net(self, Q):
        # Relative-offset index matrix: entry (i, j) = (L - 1) + i - j, mapping
        # offsets -(L-1)..(L-1) onto non-negative table indices.
        length = tf.shape(Q)[-2]
        return tf.range(length - 1, length * 2 - 1)[..., tf.newaxis] - tf.range(length)

    def mask(self, attn):
        # Lower-triangular additive causal mask (0 allowed / -1e9 blocked).
        positions = tf.range(tf.shape(attn)[-2])
        allowed = positions[..., tf.newaxis] >= positions
        return tf.where(allowed, 0., -1e9)

    def call(self, Q, K, V, use_causal=False):
        rel = self.positional_k(self.embedding_net(Q))
        # Content logits plus relative-position logits.
        logits = tf.matmul(Q, K, transpose_b=True) + tf.einsum('...id, ijd -> ...ij', Q, rel)
        # NOTE(review): scales by the full model dim, not this head's `dim` —
        # preserved as written; confirm whether per-head scaling was intended.
        scores = logits / math.sqrt(TEXT_EMBEDDING_DIM)
        if use_causal:
            scores = scores + self.mask(scores)
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, V)
|
| |
|
| |
|
class MultiheadRelativePositionalSelfAttention(keras.layers.Layer):
    """Multi-head relative-positional attention: shared Q/K/V projections
    whose outputs are split channel-wise across the heads."""

    def __init__(self):
        super().__init__()
        head_dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD
        self.heads = [RelativePositionalSelfAttention(head_dim) for _ in range(ATTENTION_HEAD)]
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, use_causal=False):
        # Project once, then hand each head its own channel slice.
        q_parts = tf.split(self.Q(q), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        k_parts = tf.split(self.K(k), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        v_parts = tf.split(self.V(v), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        outputs = [
            head(qp, kp, vp, use_causal=use_causal)
            for head, qp, kp, vp in zip(self.heads, q_parts, k_parts, v_parts)
        ]
        return tf.concat(outputs, axis=-1)
|
| |
|
| |
|
class NormalAttention(keras.layers.Layer):
    """Plain scaled dot-product attention over pre-projected Q/K/V of width `dim`."""

    def __init__(self, dim):
        # FIX: super().__init__() must run before any attribute assignment —
        # Keras layers set up attribute tracking in the base constructor and
        # raise when a subclass sets attributes first.
        super().__init__()
        self.dim = dim

    def mask(self, attn):
        # Lower-triangular additive causal mask: 0 where key <= query, -1e9 otherwise.
        mask = tf.range(tf.shape(attn)[-2])[..., tf.newaxis] >= tf.range(tf.shape(attn)[-2])
        return tf.where(mask, 0., -1e9)

    def call(self, Q, K, V, use_causal=False):
        """Return softmax(Q K^T / sqrt(dim)) V, optionally causally masked."""
        eij = tf.matmul(Q, K, transpose_b=True)
        attn = eij / math.sqrt(self.dim)
        if use_causal:
            attn += self.mask(attn)
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)
|
| |
|
| |
|
class MultiheadAttention(keras.layers.Layer):
    """Standard split-head attention built from NormalAttention heads.

    Note: no residual connection and no output projection — the concatenated
    head outputs are normalized directly.
    """

    def __init__(self):
        super().__init__()
        head_dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD
        self.heads = [NormalAttention(head_dim) for _ in range(ATTENTION_HEAD)]
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.norm = keras.layers.LayerNormalization()

    def call(self, q, k, v, use_causal=False):
        # Project once, then hand each head its own channel slice.
        q_parts = tf.split(self.Q(q), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        k_parts = tf.split(self.K(k), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        v_parts = tf.split(self.V(v), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        outputs = [
            head(qp, kp, vp, use_causal=use_causal)
            for head, qp, kp, vp in zip(self.heads, q_parts, k_parts, v_parts)
        ]
        return self.norm(tf.concat(outputs, axis=-1))
|
| |
|
| |
|
class MeshedEncoder(keras.layers.Layer):
    """Encoder block whose attention flavor is selected by the CHOICE constant."""

    def __init__(self):
        super().__init__()
        # Dispatch table: 0 = memory-augmented, 1 = relative-positional,
        # 2 = plain multi-head attention.
        attention_kinds = {
            0: MultiheadMemorizedAttention,
            1: MultiheadRelativePositionalSelfAttention,
            2: MultiheadAttention,
        }
        self.m_attn = attention_kinds[CHOICE]()
        self.f = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM)
        ])
        # NOTE(review): one LayerNormalization instance is shared by both
        # residual connections below (shared scale/offset weights).
        self.norm = keras.layers.LayerNormalization()

    def call(self, inp):
        attended = self.norm(self.m_attn(inp, inp, inp) + inp)
        return self.norm(self.f(attended) + attended)
|
| |
|
| |
|
class MeshedDecoder(keras.layers.Layer):
    """Meshed decoder block: gated cross-attention over every encoder level.

    The self-attended sequence attends to each encoder output in `tgts`, and a
    learned sigmoid gate weights each contribution before they are summed.
    """

    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.ca = Attention()
        # Gate: per-feature weights in (0, 1) from the [sa, cross] concat.
        self.dense = keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='sigmoid')
        # NOTE(review): a single LayerNormalization instance is reused for
        # every normalization in call() (shared weights).
        self.norm = keras.layers.LayerNormalization()
        self.f = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM)
        ])

    def call(self, inp):
        src, tgts = inp
        # SelfAttention applies its own residual+norm; normalized again here.
        sa = self.norm(self.sa(src))
        gated = tf.zeros(tf.shape(sa), dtype=tf.float32)
        for tgt in tgts:
            cross = self.norm(self.ca(sa, tgt, tgt) + sa)
            gate = self.dense(tf.concat([sa, cross], axis=-1))
            gated += gate * cross
        feed = self.norm(self.f(gated) + gated)
        return self.norm(feed)
|
| |
|
| |
|
class MultiLayerMeshed(keras.layers.Layer):
    """Stack of MeshedEncoder blocks whose every intermediate output feeds
    each MeshedDecoder block (meshed connectivity)."""

    def __init__(self):
        super().__init__()
        self.enc = [MeshedEncoder() for _ in range(MESHED_DEPTH)]
        self.dec = [MeshedDecoder() for _ in range(MESHED_DEPTH)]

    def call(self, inp):
        # `src` is the text sequence; `tgt` the (image) memory to be encoded.
        src, tgt = inp
        # Collect the raw memory plus the output of every encoder level.
        levels = [tgt]
        for encoder in self.enc:
            levels.append(encoder(levels[-1]))
        out = src
        for decoder in self.dec:
            out = decoder((out, levels))
        return out
|
| |
|
| |
|
class EfficientNetVision(keras.layers.Layer):
    """Frozen EfficientNetB2 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.EfficientNetB2(include_top=False)
        self.backbone.trainable = False  # used as a fixed feature extractor

    def call(self, image, **kwargs):
        return self.backbone(keras.applications.efficientnet.preprocess_input(image))
|
| |
|
| |
|
class ResnetVision(keras.layers.Layer):
    """Frozen ResNet50V2 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.ResNet50V2(include_top=False)
        self.backbone.trainable = False  # used as a fixed feature extractor

    def call(self, image, **kwargs):
        return self.backbone(keras.applications.resnet_v2.preprocess_input(image))
|
| |
|
| |
|
class VGGVision(keras.layers.Layer):
    """Frozen VGG16 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.VGG16(include_top=False)
        self.backbone.trainable = False  # used as a fixed feature extractor

    def call(self, image, **kwargs):
        return self.backbone(keras.applications.vgg16.preprocess_input(image))
|
| |
|
| |
|
class Resnet152Vision(keras.layers.Layer):
    """Frozen ResNet152V2 feature extractor with its matching preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.ResNet152V2(include_top=False)
        self.backbone.trainable = False  # used as a fixed feature extractor

    def call(self, image, **kwargs):
        return self.backbone(keras.applications.resnet_v2.preprocess_input(image))
|
| |
|
| |
|
class ShortVision(keras.layers.Layer):
    """Small trainable CNN alternative to the pretrained backbones
    (two strided convolutions, 32x total spatial downsampling)."""

    def __init__(self):
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 8, strides=(8, 8))
        self.conv2 = keras.layers.Conv2D(64, 4, strides=(4, 4))

    def call(self, image, **kwargs):
        # NOTE(review): reuses ResNetV2 preprocessing even though no ResNet is
        # involved — preserved as written; confirm this scaling is intended.
        scaled = keras.applications.resnet_v2.preprocess_input(image)
        return self.conv2(self.conv1(scaled))
|
| |
|
| |
|
class MeshedFastCaption(keras.Model):
    """Captioning model: pluggable vision backbone + meshed transformer decoder."""

    def __init__(self):
        super().__init__()
        # Backbone dispatch: selected once at construction via BACKBONE_CHOICE.
        self.vision = {
            0: Resnet152Vision,
            1: ResnetVision,
            2: VGGVision,
            3: EfficientNetVision,
            4: ShortVision,
        }[BACKBONE_CHOICE]()
        self.decoder = MultiLayerMeshed()
        # Projects backbone channels down to the text embedding width.
        self.adapt = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.dense = keras.layers.Dense(VOCAB_SIZE, activation='softmax')
        self.embedding = SeqEmbedding()

    def call(self, inputs):
        """Teacher-forced forward pass over an (image, token_ids) pair."""
        img, txt = inputs
        img = self.vision(img)
        img = self.adapt(img)
        # Flatten the spatial grid (H, W) into a single sequence of patches.
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        seq = self.embedding(txt)
        out = self.decoder((seq, img[:, tf.newaxis]))
        return self.dense(out)

    @staticmethod
    @tf.function
    def generate_caption(model, img):
        """Caption generation for a single image using the fast (parallel) loop."""
        img = model.vision(img)
        img = model.adapt(img)
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        # Seed of all-zero (padding) ids; id 1 acts as start, id 2 as end.
        txt = tf.zeros((1, 1, INPUT_SEQ_LENGTH), dtype=tf.int32)
        return MeshedFastCaption._fast_generate_from_seed(model, img, txt, tf.constant(0, dtype=tf.int32))

    @staticmethod
    @tf.function
    def _generate_from_seed(model, img, txt, index):
        """Step-by-step greedy decoding (not used by generate_caption, which
        calls _fast_generate_from_seed; kept as the slower, position-masked
        alternative)."""
        while tf.math.logical_not(tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)), tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Keep only predictions up to the current step; later slots stay 0.
            valid = tf.cast(tf.range(tf.shape(txt)[2]) <= index, dtype=tf.int32)
            new_text = new_text * valid
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text[:, :, :-1]], axis=2)
            index = index + 1
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)], axis=2)

    @staticmethod
    @tf.function
    def _fast_generate_from_seed(model, img, txt, index):
        """Parallel greedy decoding: every position is re-predicted each pass."""
        while tf.math.logical_not(tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)), tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # FIX: drop the last predicted token when prepending the start
            # token, exactly as _generate_from_seed does. Previously `txt`
            # grew by one token per iteration, which broke the fixed-length
            # position-embedding broadcast on the next pass.
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text[:, :, :-1]], axis=2)
            index = index + 1
        # Strip the start token and pad one zero so the returned sequence has
        # the same length as the seed (consistent with _generate_from_seed).
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)], axis=2)