# KhangTruong's picture
# Super-squash branch 'main' using huggingface_hub
# 6f31f53 verified
from ..imports import *
class Mask(keras.layers.Layer):
    """Builds an additive causal-attention mask.

    Given a tensor whose second-to-last axis is the sequence length L,
    returns an (L, L) float tensor containing 0. where attention is
    allowed (current and earlier positions) and -1e9 where it is not,
    ready to be added to attention logits before the softmax.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs, **kwargs):
        seq_len = tf.shape(inputs)[-2]
        # Row i may attend to columns 0..i inclusive.
        allowed = tf.sequence_mask(tf.range(1, seq_len + 1), seq_len)
        return tf.where(allowed, 0., -1e9)
class SeqEmbedding(keras.layers.Layer):
    """Token embedding plus a learned absolute positional embedding.

    The positional table is indexed 0..INPUT_SEQ_LENGTH-1, so inputs are
    expected to have exactly INPUT_SEQ_LENGTH tokens on the last axis.
    """

    def __init__(self):
        super().__init__()
        # Maps token ids to dense vectors.
        self.token_emb = keras.layers.Embedding(VOCAB_SIZE, TEXT_EMBEDDING_DIM)
        # One learned vector per sequence position.
        self.position = keras.layers.Embedding(INPUT_SEQ_LENGTH, TEXT_EMBEDDING_DIM)

    def call(self, txt):
        positions = tf.range(INPUT_SEQ_LENGTH)
        # Leading new axis lets the positional rows broadcast over the batch.
        pos_emb = self.position(positions)[tf.newaxis]
        return self.token_emb(txt) + pos_emb
class Attention(keras.layers.Layer):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    `mask`, when given, is an additive mask (0. allowed / -1e9 blocked)
    broadcast onto the attention logits.
    """

    def __init__(self):
        super().__init__()
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, mask=None, **kwargs):
        queries = self.Q(q)
        keys = self.K(k)
        values = self.V(v)
        scale = TEXT_EMBEDDING_DIM ** 0.5
        logits = tf.matmul(queries, keys, transpose_b=True) / scale
        if mask is not None:
            logits = logits + mask
        weights = tf.nn.softmax(logits, axis=-1)
        return tf.matmul(weights, values)
class SelfAttention(keras.layers.Layer):
    """Causally-masked self-attention with a residual add and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.sa = Attention()
        self.norm = keras.layers.LayerNormalization()
        self.mask = Mask()

    def call(self, inputs, **kwargs):
        causal = self.mask(inputs)
        attended = self.sa(inputs, inputs, inputs, causal)
        return self.norm(attended + inputs)
class CrossAttention(keras.layers.Layer):
    """Attention from `src` (queries) into `mem` (keys and values),
    followed by a residual add and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.ca = Attention()
        self.norm = keras.layers.LayerNormalization()

    def call(self, src, mem, **kwargs):
        attended = self.ca(src, mem, mem)
        return self.norm(attended + src)
class DecoderLayer(keras.layers.Layer):
    """Transformer decoder block: causal self-attention over the text,
    cross-attention into the encoder output, then a feed-forward sub-layer.

    Each sub-layer applies its own residual connection and LayerNorm
    internally (see SelfAttention / CrossAttention / FeedForward).

    Input: a `(seq, enc)` pair — embedded token sequence and encoded
    image features.
    """

    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.ca = CrossAttention()
        self.ff = FeedForward()
        # FIX: the previous version also instantiated a LayerNormalization and
        # a Mask here that were never called (the sub-layers mask and
        # normalize themselves). They are removed; unbuilt layers hold no
        # variables, so saved checkpoints are unaffected.

    def call(self, inp):
        seq, enc = inp
        x = self.sa(seq)      # causal self-attention over the token sequence
        x = self.ca(x, enc)   # attend into the encoded image features
        return self.ff(x)
class EncoderLayer(keras.layers.Layer):
    """Single encoder block: unmasked self-attention with residual + LayerNorm."""

    def __init__(self):
        super().__init__()
        self.attn = Attention()
        self.norm = keras.layers.LayerNormalization()

    def call(self, inputs):
        attended = self.attn(inputs, inputs, inputs)
        return self.norm(inputs + attended)
class FeedForward(keras.layers.Layer):
    """Position-wise feed-forward sub-layer with residual + LayerNorm.

    Expands to twice the embedding width with ReLU, projects back down,
    applies dropout, then adds the input and normalizes.
    """

    def __init__(self):
        super().__init__()
        self.seq = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM * 2, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM),
            keras.layers.Dropout(0.3),
        ])
        self.norm = keras.layers.LayerNormalization()

    def call(self, inputs, **kwargs):
        projected = self.seq(inputs)
        return self.norm(projected + inputs)
class FastCaption(keras.Model):
    """Image captioner: frozen VGG16 backbone -> Dense adapter -> one
    EncoderLayer over the flattened feature map -> one DecoderLayer over
    the embedded token sequence -> softmax over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        # Frozen convolutional feature extractor.
        self.backbone = keras.applications.VGG16(include_top=False)
        self.backbone.trainable = False
        self.decoder = DecoderLayer()
        self.encoder = EncoderLayer()
        # Projects backbone channels down to the text embedding width.
        self.adapt = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        # Per-token probability distribution over the vocabulary.
        self.dense = keras.layers.Dense(VOCAB_SIZE, activation='softmax')
        self.embedding = SeqEmbedding()

    def call(self, inputs):
        img, txt = inputs
        img = keras.applications.vgg16.preprocess_input(img)
        img = self.backbone(img)  # (batch, H, W, C); VGG16 emits C=512 — the old "2048" note looked wrong
        img = self.adapt(img)  # (batch, H, W, TEXT_EMBEDDING_DIM)
        # Flatten the spatial grid into a sequence of H*W feature vectors.
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        img = self.encoder(img)
        seq = self.embedding(txt)  # (batch, p, length, dim) — p presumably a captions-per-image axis; TODO confirm
        # Extra axis lets the encoded image broadcast across seq's p axis.
        out = self.decoder((seq, img[:, tf.newaxis]))
        return self.dense(out)

    @staticmethod
    @tf.function
    def generate_caption(model, img):
        """Greedily generates a caption for a raw image batch of size 1."""
        img = keras.applications.vgg16.preprocess_input(img)
        img = model.backbone(img)
        img = model.adapt(img)
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        img = model.encoder(img)
        # Seed sequence of all zeros (0 presumably the padding id — TODO confirm).
        txt = tf.zeros((1, 1, INPUT_SEQ_LENGTH), dtype=tf.int32)
        return FastCaption._generate_from_seed(model, img, txt, tf.constant(0, dtype=tf.int32))

    @staticmethod
    @tf.function
    def _generate_from_seed(model, img, txt, index):
        """Greedy decode loop: stops when token id 2 appears (presumably
        end-of-sequence — TODO confirm) or when no padding zeros remain.
        """
        while tf.math.logical_not(tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)), tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            # NOTE(review): unlike call(), no extra axis is inserted on img
            # here; tf.matmul batch broadcasting appears to absorb the rank
            # difference — confirm intended.
            out = model.decoder((seq, img))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Keep only predictions up to the current step; zero the rest.
            valid = tf.cast(tf.range(tf.shape(txt)[2]) <= index, dtype=tf.int32)
            new_text = new_text * valid
            # Shift right and prepend token id 1 (presumably start-of-sequence).
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text[:, :, :-1]], axis=2)
            index = index + 1
        # Drop the leading start token; pad with a trailing zero.
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)], axis=2)
class MemorizedAttention(keras.layers.Layer):
    """Scaled dot-product attention augmented with EXPANSION_LENGTH learned
    "memory" key/value slots that every query can attend to.

    Parameters
    ----------
    dim : width of the Q/K/V projections and of the memory slots.
    """

    def __init__(self, dim):
        super().__init__()
        self.Q = keras.layers.Dense(dim)
        self.K = keras.layers.Dense(dim)
        self.V = keras.layers.Dense(dim)
        self.dim = dim

    def build(self, input_shape):
        # Learned slots shared across the batch; tiled per-call to match q.
        self.memory_k = self.add_weight(
            shape=(1, 1, EXPANSION_LENGTH, self.dim),
            initializer="glorot_uniform",
            trainable=True,
            name="memory_k"
        )
        self.memory_v = self.add_weight(
            shape=(1, 1, EXPANSION_LENGTH, self.dim),
            initializer="glorot_uniform",
            trainable=True,
            name="memory_v"
        )

    def call(self, q, k, v, use_causal=False):
        # Tile the (1, 1, E, dim) memory over the leading batch-like axes of q.
        repeater = tf.concat([tf.shape(q)[:-2], tf.constant([1, 1], dtype=tf.int32)], axis=-1)
        memory_k = tf.tile(self.memory_k, repeater)
        memory_v = tf.tile(self.memory_v, repeater)
        Q = self.Q(q)
        K = tf.concat([self.K(k), memory_k], axis=-2)
        V = tf.concat([self.V(v), memory_v], axis=-2)
        # NOTE(review): scaled by the full embedding width, not self.dim, to
        # keep the original numerics (NormalAttention scales by its own dim).
        attn = tf.matmul(Q, K, transpose_b=True) / math.sqrt(TEXT_EMBEDDING_DIM)
        if use_causal:
            # BUG FIX: the old code added a (L, L) mask from Mask() to logits
            # of shape (..., L, k_len + EXPANSION_LENGTH), which cannot
            # broadcast. Build a mask that is causal over the real keys and
            # always open (0.) over the memory slots.
            q_len = tf.shape(Q)[-2]
            k_len = tf.shape(K)[-2] - EXPANSION_LENGTH
            causal = tf.where(tf.sequence_mask(tf.range(1, q_len + 1), k_len), 0., -1e9)
            mem_open = tf.zeros((q_len, EXPANSION_LENGTH))
            attn += tf.concat([causal, mem_open], axis=-1)
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)
class MultiheadMemorizedAttention(keras.layers.Layer):
    """Multi-head wrapper over MemorizedAttention, with an output projection,
    residual connection, and LayerNorm. Each head receives the full-width
    inputs and projects them down internally."""

    def __init__(self):
        super().__init__()
        head_dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD
        self.heads = [MemorizedAttention(head_dim) for _ in range(ATTENTION_HEAD)]
        self.norm = keras.layers.LayerNormalization()
        self.dense = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, use_causal=False):
        per_head = []
        for head in self.heads:
            per_head.append(head(q, k, v, use_causal=use_causal))
        merged = self.dense(tf.concat(per_head, axis=-1))
        return self.norm(merged + q)
class RelativePositionalSelfAttention(keras.layers.Layer):
    """Scaled dot-product attention with learned relative positional keys.

    The logit between query i and key j gets an extra term Q_i . r_{i-j},
    where r is looked up from an embedding table indexed by the shifted
    relative offset i - j + (L - 1), which lies in [0, 2L - 2].
    Like NormalAttention, this layer owns no Q/K/V projections.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # 81 presumably corresponds to a 9x9 image-feature grid — TODO confirm.
        self.seq_length = 81
        # Table of 2*81 + 1 = 163 offsets; valid for runtime L up to 82.
        self.positional_k = keras.layers.Embedding(
            self.seq_length * 2 + 1,
            self.dim
        )

    def embedding_net(self, Q):
        # Index matrix idx[i, j] = i - j + (L - 1), i.e. the relative offset
        # shifted to be non-negative.
        seq_length = tf.shape(Q)[-2]
        return tf.range(seq_length - 1, seq_length * 2 - 1)[..., tf.newaxis] - tf.range(seq_length)

    def mask(self, attn):
        # Additive causal mask: row i may attend to columns j <= i.
        mask = tf.range(tf.shape(attn)[-2])[..., tf.newaxis] >= tf.range(tf.shape(attn)[-2])
        return tf.where(mask, 0., -1e9)

    def call(self, Q, K, V, use_causal=False):
        emb = self.positional_k(self.embedding_net(Q))
        # Content-content term plus content-position term (einsum contracts
        # the feature axis of Q against each relative-position vector).
        eij = tf.matmul(Q, K, transpose_b=True) + tf.einsum('...id, ijd -> ...ij', Q, emb)
        # NOTE(review): scales by the full embedding width rather than
        # self.dim (the per-head width) — confirm intended.
        attn = eij / math.sqrt(TEXT_EMBEDDING_DIM)
        if use_causal:
            mask = self.mask(attn)
            attn += mask
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)
class MultiheadRelativePositionalSelfAttention(keras.layers.Layer):
    """Multi-head relative-position attention: project Q/K/V once at full
    width, split along the feature axis, and run one head per slice.
    Outputs are concatenated (no output projection, residual, or norm)."""

    def __init__(self):
        super().__init__()
        head_dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD
        self.heads = [RelativePositionalSelfAttention(head_dim) for _ in range(ATTENTION_HEAD)]
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)

    def call(self, q, k, v, use_causal=False):
        q_parts = tf.split(self.Q(q), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        k_parts = tf.split(self.K(k), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        v_parts = tf.split(self.V(v), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        outputs = [
            head(qh, kh, vh, use_causal=use_causal)
            for head, qh, kh, vh in zip(self.heads, q_parts, k_parts, v_parts)
        ]
        return tf.concat(outputs, axis=-1)
class NormalAttention(keras.layers.Layer):
    """Plain scaled dot-product attention over already-projected Q/K/V.

    Unlike `Attention`, this layer owns no projection weights; callers
    (e.g. MultiheadAttention) project and split before invoking it.

    Parameters
    ----------
    dim : the per-head feature width used for logit scaling.
    """

    def __init__(self, dim):
        # FIX: call super().__init__() before setting any attribute — Keras
        # layers require initialization before attribute tracking, and every
        # other layer in this file follows that order.
        super().__init__()
        self.dim = dim

    def mask(self, attn):
        # Additive causal mask: row i may attend to columns j <= i.
        mask = tf.range(tf.shape(attn)[-2])[..., tf.newaxis] >= tf.range(tf.shape(attn)[-2])
        return tf.where(mask, 0., -1e9)

    def call(self, Q, K, V, use_causal=False):
        eij = tf.matmul(Q, K, transpose_b=True)
        attn = eij / math.sqrt(self.dim)  # scale by this head's width
        if use_causal:
            attn += self.mask(attn)
        attn = tf.nn.softmax(attn, axis=-1)
        return tf.matmul(attn, V)
class MultiheadAttention(keras.layers.Layer):
    """Standard multi-head attention built from NormalAttention heads:
    project at full width, split per head, concatenate, then LayerNorm
    (no output projection or residual)."""

    def __init__(self):
        super().__init__()
        head_dim = TEXT_EMBEDDING_DIM // ATTENTION_HEAD
        self.heads = [NormalAttention(head_dim) for _ in range(ATTENTION_HEAD)]
        self.Q = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.K = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.V = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        self.norm = keras.layers.LayerNormalization()

    def call(self, q, k, v, use_causal=False):
        q_parts = tf.split(self.Q(q), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        k_parts = tf.split(self.K(k), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        v_parts = tf.split(self.V(v), num_or_size_splits=ATTENTION_HEAD, axis=-1)
        outputs = [
            head(qh, kh, vh, use_causal=use_causal)
            for head, qh, kh, vh in zip(self.heads, q_parts, k_parts, v_parts)
        ]
        return self.norm(tf.concat(outputs, axis=-1))
class MeshedEncoder(keras.layers.Layer):
    """One encoder block of the meshed stack: self-attention plus a small
    feed-forward, each with residual + LayerNorm. The attention flavor is
    selected once at construction by the module-level CHOICE constant."""

    def __init__(self):
        super().__init__()
        attention_types = {
            0: MultiheadMemorizedAttention,
            1: MultiheadRelativePositionalSelfAttention,
            2: MultiheadAttention,
        }
        self.m_attn = attention_types[CHOICE]()
        self.f = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM)
        ])
        # NOTE(review): a single LayerNormalization instance is shared by both
        # residual sites in call() (tied scale/offset) — preserved as-is;
        # confirm the tying is intentional.
        self.norm = keras.layers.LayerNormalization()

    def call(self, inp):
        attended = self.norm(self.m_attn(inp, inp, inp) + inp)
        return self.norm(self.f(attended) + attended)
class MeshedDecoder(keras.layers.Layer):
    """Meshed decoder block: causal self-attention over the text, then a
    gated sum of cross-attention results against every encoder level in
    `tgts`, finished with a feed-forward residual and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.ca = Attention()
        # Produces per-feature gates in (0, 1) from [query, cross-attn] pairs.
        self.dense = keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='sigmoid')
        # NOTE(review): one LayerNormalization instance is reused at every
        # normalization site in call() (tied weights) — preserved as-is.
        self.norm = keras.layers.LayerNormalization()
        self.f = keras.Sequential([
            keras.layers.Dense(TEXT_EMBEDDING_DIM, activation='relu'),
            keras.layers.Dense(TEXT_EMBEDDING_DIM)
        ])

    def call(self, inp):
        src, tgts = inp
        queries = self.norm(self.sa(src))
        accum = tf.zeros(tf.shape(queries), dtype=tf.float32)
        for level in tgts:
            cross = self.norm(self.ca(queries, level, level) + queries)
            gate = self.dense(tf.concat([queries, cross], axis=-1))
            accum = accum + gate * cross
        out = self.norm(self.f(accum) + accum)
        return self.norm(out)
class MultiLayerMeshed(keras.layers.Layer):
    """Stack of MESHED_DEPTH encoders and decoders wired meshed-memory style:
    every decoder sees the outputs of every encoder level, including the raw
    input features."""

    def __init__(self):
        super().__init__()
        self.enc = [MeshedEncoder() for _ in range(MESHED_DEPTH)]
        self.dec = [MeshedDecoder() for _ in range(MESHED_DEPTH)]

    def call(self, inp):
        src, tgt = inp
        # Encoder pyramid: level 0 is the raw (image) features.
        levels = [tgt]
        for encoder in self.enc:
            levels.append(encoder(levels[-1]))
        out = src
        for decoder in self.dec:
            out = decoder((out, levels))
        return out
class EfficientNetVision(keras.layers.Layer):
    """Frozen EfficientNetB2 feature extractor; applies the matching
    keras.applications preprocessing before the backbone."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.EfficientNetB2(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        preprocessed = keras.applications.efficientnet.preprocess_input(image)
        return self.backbone(preprocessed)
class ResnetVision(keras.layers.Layer):
    """Frozen ResNet50V2 feature extractor; applies the matching
    keras.applications preprocessing before the backbone."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.ResNet50V2(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        preprocessed = keras.applications.resnet_v2.preprocess_input(image)
        return self.backbone(preprocessed)
class VGGVision(keras.layers.Layer):
    """Frozen VGG16 feature extractor; applies the matching
    keras.applications preprocessing before the backbone."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.VGG16(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        preprocessed = keras.applications.vgg16.preprocess_input(image)
        return self.backbone(preprocessed)
class Resnet152Vision(keras.layers.Layer):
    """Frozen ResNet152V2 feature extractor; applies the matching
    keras.applications preprocessing before the backbone."""

    def __init__(self):
        super().__init__()
        self.backbone = keras.applications.ResNet152V2(include_top=False)
        self.backbone.trainable = False

    def call(self, image, **kwargs):
        preprocessed = keras.applications.resnet_v2.preprocess_input(image)
        return self.backbone(preprocessed)
class ShortVision(keras.layers.Layer):
    """Small trainable CNN alternative to the pretrained backbones: two
    strided convolutions that downsample each spatial dimension 32x."""

    def __init__(self):
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 8, strides=(8, 8))
        self.conv2 = keras.layers.Conv2D(64, 4, strides=(4, 4))

    def call(self, image, **kwargs):
        # Reuses the ResNetV2 preprocessing purely for pixel-value scaling.
        preprocessed = keras.applications.resnet_v2.preprocess_input(image)
        x = self.conv1(preprocessed)
        return self.conv2(x)
class MeshedFastCaption(keras.Model):
    """Image captioner built on the meshed encoder/decoder stack: a vision
    backbone (selected by BACKBONE_CHOICE) -> Dense adapter ->
    MultiLayerMeshed -> softmax over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        # Backbone selected once at construction from the module-level choice.
        self.vision = {
            0: Resnet152Vision,
            1: ResnetVision,
            2: VGGVision,
            3: EfficientNetVision,
            4: ShortVision,
        }[BACKBONE_CHOICE]()
        self.decoder = MultiLayerMeshed()
        # Projects backbone channels down to the text embedding width.
        self.adapt = keras.layers.Dense(TEXT_EMBEDDING_DIM)
        # Per-token probability distribution over the vocabulary.
        self.dense = keras.layers.Dense(VOCAB_SIZE, activation='softmax')
        self.embedding = SeqEmbedding()

    def call(self, inputs):
        img, txt = inputs
        img = self.vision(img)
        img = self.adapt(img)  # (batch, H, W, TEXT_EMBEDDING_DIM)
        # Flatten the spatial grid into a sequence of H*W feature vectors.
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        seq = self.embedding(txt)  # (batch, p, length, dim) — p presumably a captions-per-image axis; TODO confirm
        # Extra axis lets the image sequence broadcast across seq's p axis.
        out = self.decoder((seq, img[:, tf.newaxis]))
        return self.dense(out)

    @staticmethod
    @tf.function
    def generate_caption(model, img):
        """Greedily generates a caption for a raw image batch of size 1."""
        img = model.vision(img)
        img = model.adapt(img)
        img = tf.reshape(img, [tf.shape(img)[0], tf.shape(img)[1] * tf.shape(img)[2], tf.shape(img)[3]])
        # Seed sequence of all zeros (0 presumably the padding id — TODO confirm).
        txt = tf.zeros((1, 1, INPUT_SEQ_LENGTH), dtype=tf.int32)
        return MeshedFastCaption._fast_generate_from_seed(model, img, txt, tf.constant(0, dtype=tf.int32))

    @staticmethod
    @tf.function
    def _generate_from_seed(model, img, txt, index):
        """Step-masked greedy decode loop (one new token committed per pass).

        NOTE(review): not referenced by generate_caption above, which calls
        _fast_generate_from_seed — possibly kept for external callers or as
        the slower/safer variant; confirm before removing.
        """
        while tf.math.logical_not(tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)), tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Keep only predictions up to the current step; zero the rest.
            valid = tf.cast(tf.range(tf.shape(txt)[2]) <= index, dtype=tf.int32)
            new_text = new_text * valid
            # Shift right and prepend token id 1 (presumably start-of-sequence).
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text[:, :, :-1]], axis=2)
            index = index + 1
        # Drop the leading start token; pad with a trailing zero.
        return tf.concat([txt[:, :, 1:], tf.zeros((1, 1, 1), dtype=tf.int32)], axis=2)

    @staticmethod
    @tf.function
    def _fast_generate_from_seed(model, img, txt, index):
        """Greedy decode loop that re-feeds the whole predicted sequence each
        pass (no per-step validity mask, no truncation before the concat).

        Stops when token id 2 appears (presumably end-of-sequence — TODO
        confirm) or when no padding zeros remain.
        """
        while tf.math.logical_not(tf.math.logical_or(tf.math.reduce_any(tf.equal(txt, 2)), tf.math.reduce_all(tf.not_equal(txt, 0)))):
            seq = model.embedding(txt)
            out = model.decoder((seq, img[:, tf.newaxis]))
            prob = model.dense(out)
            new_text = tf.argmax(prob, axis=-1, output_type=tf.int32)
            # Prepend token id 1 (presumably start-of-sequence); the sequence
            # grows by one each iteration here, unlike _generate_from_seed.
            txt = tf.concat([tf.ones((1, 1, 1), dtype=tf.int32), new_text], axis=2)
            index = index + 1
        # Return everything after the start token.
        return txt[..., 1:]