|
|
from transformers import PreTrainedTokenizerFast |
|
|
|
|
|
# BPE tokenizer used by process_text to turn prompts into token ids.
# Loaded once, at import time, from the repository directory below.
_TOKENIZER_DIR = "/repository/bpe"
token2vec = PreTrainedTokenizerFast.from_pretrained(_TOKENIZER_DIR)
|
|
|
|
|
from typing import Dict, Any |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.layers import Dense, LayerNormalization, Conv2D, UpSampling2D, Embedding, MultiHeadAttention |
|
|
from tensorflow.keras.saving import register_keras_serializable |
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
def _make_config(T, hidden_dim, heads, layers):
    """Build the hyper-parameter dict shared by all size presets.

    Args:
        T: number of diffusion timesteps.
        hidden_dim: transformer width.
        heads: attention heads per transformer block.
        layers: number of transformer blocks.

    Returns:
        A fresh dict of architecture settings plus the linear noise schedule:
        ``beta`` (per-step variance, linearly spaced from 1e-4 to 0.02),
        ``alpha = 1 - beta`` and ``a = cumprod(alpha)`` (alpha-bar).
    """
    beta = np.linspace(1e-4, 0.02, T)
    alpha = 1 - beta
    a = np.cumprod(alpha)

    # Key order matches the original per-preset literals.
    return {
        "filters": [128, 256],
        "hidden_dim": hidden_dim,
        "heads": heads,
        "layers": layers,
        "patch_size": 4,
        "batch_size": 64,
        "T": T,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": beta,
        "alpha": alpha,
        "a": a,
    }


def small_config():
    """Small preset: T=500, width 384, 6 heads, 8 layers."""
    return _make_config(T=500, hidden_dim=384, heads=6, layers=8)


def med_config():
    """Medium preset: T=1000, width 768, 12 heads, 12 layers."""
    return _make_config(T=1000, hidden_dim=768, heads=12, layers=12)


def large_config():
    """Large preset: T=1000, width 1024, 16 heads, 24 layers."""
    return _make_config(T=1000, hidden_dim=1024, heads=16, layers=24)
|
|
|
|
|
# Active preset. Presumably chosen to match the "med" diffusion checkpoint
# loaded further down in this file — keep the two in sync.
config = med_config()

# Expose every preset entry as a module-level constant used by the helpers
# and layers below.
(filters, hidden_dim, heads, layers, patch_size, batch_size, T,
 context_size, image_size, latent_shape, beta, alpha, a) = (
    config[key]
    for key in (
        "filters", "hidden_dim", "heads", "layers", "patch_size",
        "batch_size", "T", "context_size", "image_size", "latent_shape",
        "beta", "alpha", "a",
    )
)
|
|
|
|
|
|
|
|
@register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Residual conv block with a 1x1 channel projection on the skip path.

    Computes ``skip + conv2(conv1(skip))`` where ``skip = reshape(x)`` is a
    1x1 conv to ``filters`` channels, ``conv1`` is a ``p x p`` conv with swish
    activation and ``conv2`` a linear ``p x p`` conv. All convs use stride 1
    and "same" padding, so the spatial size is unchanged.

    Args:
        filters: output channel count.
        p: kernel size of the two inner convolutions.
    """

    def __init__(self, filters, p, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.p = p
        # Attribute names are kept stable: saved .keras weights are looked up
        # through them.
        self.reshape = Conv2D(filters, kernel_size=1, strides=1, padding="same")
        self.conv1 = Conv2D(filters, kernel_size=p, strides=1, padding="same", activation="swish")
        self.conv2 = Conv2D(filters, kernel_size=p, strides=1, padding="same")

    def call(self, x):
        skip = self.reshape(x)
        return skip + self.conv2(self.conv1(skip))

    def get_config(self):
        return {**super().get_config(), "filters": self.filters, "p": self.p}
|
|
|
|
|
@register_keras_serializable()
class DownBlock(tf.keras.layers.Layer):
    """Encoder stage: a chain of ResBlocks followed by one 2x2 max-pool.

    Args:
        filters: output channels of each ResBlock (kernel size 3), applied
            in the given order.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # ``resBlocks`` / ``pool`` names are part of the saved-weights layout.
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    def call(self, x):
        h = x
        for res_block in self.resBlocks:
            h = res_block(h)
        # Spatial resolution is halved once per stage, after all ResBlocks.
        return self.pool(h)

    def get_config(self):
        return {**super().get_config(), "filters": self.filters}
|
|
|
|
|
@register_keras_serializable()
class UpBlock(tf.keras.layers.Layer):
    """Decoder stage: one 2x bilinear upsample followed by a chain of ResBlocks.

    Args:
        filters: output channels of each ResBlock (kernel size 3), applied
            in the given order.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # ``resBlocks`` / ``upSample`` names are part of the saved-weights layout.
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.upSample = UpSampling2D(size=2, interpolation="bilinear")

    def call(self, x):
        # Double the spatial resolution first, then refine with the ResBlocks.
        h = self.upSample(x)
        for res_block in self.resBlocks:
            h = res_block(h)
        return h

    def get_config(self):
        return {**super().get_config(), "filters": self.filters}
|
|
|
|
|
|
|
|
@register_keras_serializable()
class Encoder(tf.keras.Model):
    """Variational conv encoder returning a sampled latent and its Gaussian parameters.

    Each entry of ``filters`` adds a DownBlock (two ResBlocks + 2x2 max-pool).
    A final 1x1 conv emits ``2 * latent_dim`` channels, which are split into
    ``mu`` and ``logvar``.

    Args:
        filters: channel width of each DownBlock.
        latent_dim: number of latent channels.
    """

    def __init__(self, filters, latent_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.filters = filters
        self.latent_dim = latent_dim
        self.downBlocks = [DownBlock([f,f]) for f in filters]
        self.latent_proj = Conv2D(latent_dim * 2, kernel_size=1, strides=1, padding="same", activation="linear")

    @tf.function
    def sample(self, mu, logvar):
        """Reparameterisation: ``mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
        eps = tf.random.normal(shape=tf.shape(mu))
        return eps * tf.exp(logvar * .5) + mu

    def call(self, x, training=1):
        """Return ``(z, mu, logvar)``; ``training`` is accepted but unused."""
        for downBlock in self.downBlocks:
            x = downBlock(x)
        x = self.latent_proj(x)
        mu, logvar = tf.split(x, 2, axis=-1)
        z = self.sample(mu, logvar)
        return z, mu, logvar

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "latent_dim": self.latent_dim})
        return config

    def compute_output_shape(self, input_shape):
        """Shapes of ``(z, mu, logvar)`` for an NHWC ``input_shape``.

        Each DownBlock halves height and width once (2x2 max-pool with
        "valid" padding, i.e. floor division). The 1x1 projection and split
        leave ``latent_dim`` channels. Unknown (``None``) dims stay ``None``.
        """
        batch, height, width = input_shape[0], input_shape[1], input_shape[2]
        for _ in self.downBlocks:
            height = None if height is None else height // 2
            width = None if width is None else width // 2
        latent = (batch, height, width, self.latent_dim)
        return latent, latent, latent
|
|
|
|
|
@register_keras_serializable()
class Decoder(tf.keras.Model):
    """Conv decoder mapping latents back to 3-channel images.

    A 1x1 conv projects the latent to ``filters[0]`` channels. Each UpBlock
    then upsamples 2x (bilinear) and applies two ResBlocks. A final 3x3 conv
    with linear activation produces 3 output channels.

    Args:
        filters: channel widths, one UpBlock per entry, built in the order given.
        img_size: output height/width reported by ``compute_output_shape``.
    """

    def __init__(self, filters, img_size, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        # NOTE(review): ``self.filters`` holds the REVERSED list. It is only
        # read by get_config, which reverses it back so from_config receives
        # the original order. The layers below, however, are built from the
        # un-reversed argument ``filters``. If an Encoder mirror (widest
        # first) was intended, that is not what happens. The checkpoint
        # loaded from decoder.keras presumably depends on this exact layout,
        # so confirm before changing it.
        self.filters = filters[::-1]
        self.img_size = img_size
        self.undo_latent_proj = Conv2D(filters[0], kernel_size=1, strides=1, padding="same")
        self.upBlocks = [UpBlock([f,f]) for f in filters]
        self.conv_proj = Conv2D(3, kernel_size=3, padding="same", activation="linear")

    def call(self, z, training=1):
        # ``training`` is accepted but not used.
        z = self.undo_latent_proj(z)
        for upBlock in self.upBlocks:
            z = upBlock(z)
        x = self.conv_proj(z)
        return x

    def get_config(self):
        config = super().get_config()
        # Undo the reversal done in __init__ so from_config rebuilds the same layers.
        config.update({
            "filters": self.filters[::-1],
            "img_size": self.img_size})
        return config

    def compute_output_shape(self, input_shape):
        # Assumes the latent side is img_size / 2**len(filters), since each
        # UpBlock doubles H and W. This is not checked against input_shape.
        return (input_shape[0], self.img_size, self.img_size, 3)
|
|
|
|
|
|
|
|
def process_text(text):
    """Tokenize ``text`` into exactly ``context_size`` token ids.

    Ids beyond ``context_size`` are dropped. Shorter sequences are
    right-padded with id 0.

    Returns:
        A NumPy array of length ``context_size``.
    """
    ids = token2vec.encode(text)
    padding = [0] * max(context_size - len(ids), 0)
    return np.array((ids + padding)[:context_size])
|
|
|
|
|
|
|
|
def normalise_img(img_tensor):
    """Map values from [-1, 1] (the range produced by ``prep_img``) to [0, 1].

    Computes ``0.5 * x + 0.5``. Values outside [-1, 1] are not clipped.

    Args:
        img_tensor: array, tensor or scalar.

    Returns:
        A new object holding the rescaled values. The input is left unchanged.
    """
    # Out-of-place arithmetic: the previous ``img = img_tensor; img *= 0.5``
    # aliased the argument and silently mutated the caller's NumPy array.
    return img_tensor * 0.5 + 0.5
|
|
|
|
|
def prep_img(img_tensor):
    """Scale pixel values from [0, 255] to [-1, 1] (``x / 127.5 - 1``).

    Args:
        img_tensor: array-like of pixel values (presumably a uint8 image).

    Returns:
        A new floating-point NumPy array. The input is left unchanged.
    """
    # The division already allocates a fresh array, so the old defensive
    # ``.copy()`` only doubled the memory traffic. ``np.asarray`` also lets
    # plain sequences through.
    return np.asarray(img_tensor) / 127.5 - 1
|
|
|
|
|
def noisify_img(img_tensor, t, a):
    """Sample ``x_t`` from the forward diffusion process given clean ``x_0``.

    Computes ``x_t = sqrt(a[t]) * x_0 + sqrt(1 - a[t]) * eps`` with
    ``eps ~ N(0, I)``.

    Args:
        img_tensor: clean input ``x_0`` (NumPy array). When ``t`` is an
            array, its first axis is the batch axis.
        t: an integer timestep, or an integer array of shape ``(batch,)``
            with one timestep per sample.
        a: per-timestep alpha-bar values (cumulative product of alpha),
            indexable by ``t``.

    Returns:
        ``(x_t, epsilon)``, where ``epsilon`` is the float32 noise that was added.
    """
    epsilon = np.random.normal(0, 1, img_tensor.shape).astype(np.float32)

    a_bar = a[t]
    if np.ndim(a_bar) > 0:
        # Per-sample timesteps: reshape (batch,) -> (batch, 1, ..., 1) so each
        # coefficient scales its own sample. Without this it would broadcast
        # against the trailing (channel) axis.
        extra = np.ndim(img_tensor) - np.ndim(a_bar)
        a_bar = np.reshape(a_bar, np.shape(a_bar) + (1,) * extra)

    sqrt_alpha_bar = np.sqrt(a_bar)
    sqrt_one_minus_alpha_bar = np.sqrt(1 - a_bar)
    x_t = sqrt_alpha_bar * img_tensor + sqrt_one_minus_alpha_bar * epsilon
    return x_t, epsilon
|
|
|
|
|
def denoise_step(x_t, eps_hat, t, a, beta):
    """Reverse one DDPM step: x_t -> x_{t-1}.

    Ancestral sampling update (Ho et al. 2020, Algorithm 2) with 0-indexed
    timesteps, i.e. ``t`` runs from ``T - 1`` down to ``0``:

        x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps_hat)
                  + sqrt(beta_t) * z

    where ``z ~ N(0, I)`` for ``t > 0`` and ``z = 0`` at the final step ``t == 0``.

    Args:
        x_t: current noisy sample (tensor or array).
        eps_hat: predicted noise, same shape as ``x_t``.
        t: integer timestep index into ``a`` and ``beta``.
        a: per-timestep alpha-bar values (cumulative product of alpha).
        beta: per-timestep noise variances.

    Returns:
        The sample x_{t-1} as a tensor.
    """
    a_bar_t = tf.convert_to_tensor(a[t], dtype=tf.float32)
    a_bar_prev = tf.convert_to_tensor(a[t - 1] if t > 0 else 1.0, dtype=tf.float32)
    a_t = a_bar_t / a_bar_prev  # alpha_t = abar_t / abar_{t-1}
    beta_t = tf.convert_to_tensor(beta[t], dtype=tf.float32)

    # Clamp to keep rsqrt/sqrt and the division away from zero.
    sqrt_recip_a_t = tf.math.rsqrt(tf.maximum(a_t, 1e-5))
    sqrt_one_minus_ab = tf.sqrt(tf.maximum(1. - a_bar_t, 1e-5))

    eps_term = (beta_t / sqrt_one_minus_ab) * eps_hat
    mean = sqrt_recip_a_t * (x_t - eps_term)

    # With 0-indexed timesteps only t == 0 (the last step) is noise-free.
    # The previous ``t > 1`` also dropped the noise at t == 1, an off-by-one
    # carried over from the 1-indexed formulation in the paper.
    if t > 0:
        noise = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        sigma = tf.sqrt(tf.maximum(beta_t, 1e-5))
        return mean + sigma * noise
    return mean
|
|
|
|
|
|
|
|
@register_keras_serializable()
class TransformerBlock(tf.keras.Layer):
    """Pre-LayerNorm transformer encoder block.

    ``x -> x + MHA(norm1(x))``, then ``x -> x + mlp_down(mlp_up(norm2(x)))``.
    The MLP hidden layer is ``4 * latent_dim`` wide with GELU activation.

    Args:
        context_size: sequence length the block is used with. It is only
            stored so that it round-trips through ``get_config``.
        head_no: number of attention heads. Per-head key size is
            ``latent_dim // head_no``.
        latent_dim: token width of inputs and outputs.
    """

    def __init__(self, context_size, head_no, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.context_size = context_size
        self.head_no = head_no
        self.latent_dim = latent_dim
        # Sub-layer attribute names and creation order are kept stable:
        # saved weights are addressed through these attributes.
        self.attn = MultiHeadAttention(
            num_heads=head_no,
            key_dim=latent_dim // head_no,
            output_shape=latent_dim,
        )
        self.mlp_up = Dense(latent_dim * 4, activation="gelu")
        self.mlp_down = Dense(latent_dim)
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()

    def call(self, x):
        q = self.norm1(x)
        x = x + self.attn(q, q, q)
        return x + self.mlp_down(self.mlp_up(self.norm2(x)))

    def build(self, input_shape):
        # Nothing is created here; sub-layers build on their first call.
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        # Both residual branches preserve the input shape.
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            context_size=self.context_size,
            head_no=self.head_no,
            latent_dim=self.latent_dim,
        )
        return config
|
|
|
|
|
|
|
|
@register_keras_serializable()
class AdaptiveLayerNorm(tf.keras.Layer):
    """Layer normalization whose scale and shift are predicted from a condition.

    ``x`` is normalized without learned affine parameters and then modulated
    as ``norm(x) * (1 + M(cond)) + b(cond)``. ``M`` and ``b`` are Dense
    projections of ``cond`` to the width of ``x``'s axis 2, broadcast over
    axis 1 of ``x``.

    Args:
        eps: epsilon of the inner ``LayerNormalization``.
    """

    def __init__(self, eps=1e-6, **kwargs):
        # Initialise the base Layer before assigning any tracked sub-layer.
        super(AdaptiveLayerNorm, self).__init__(**kwargs)
        self.eps = eps
        self.layernorm = LayerNormalization(epsilon=eps, center=False, scale=False)

    def build(self, input_shape):
        # Projection width follows the feature axis (axis 2) of ``x``.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")
        self.b = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")

    def call(self, x, cond):
        gamma = self.M(cond)
        beta = self.b(cond)
        x = self.layernorm(x)
        # Insert axis 1 so the per-example modulation broadcasts over it.
        # cond is presumably (batch, features) — TODO confirm.
        return x * (1 + tf.expand_dims(gamma, 1)) + tf.expand_dims(beta, 1)

    def get_config(self):
        config = super().get_config()
        # Persist eps so a non-default value survives save/load. Older
        # configs without the key fall back to the default.
        config.update({"eps": self.eps})
        return config
|
|
|
|
|
|
|
|
@register_keras_serializable()
class ImageEmbedder(tf.keras.Layer):
    """Turn a latent image into a token sequence with learned positions.

    1. A per-pixel Dense lifts the channels to ``emb_dim``.
    2. A conv with kernel and stride ``patch_size`` maps each non-overlapping
       patch to one token.
    3. The patch grid is flattened and a learned positional embedding is added.

    Args:
        latent_size: side length of the latent, assumed square. The
            positional table has ``(latent_size // patch_size) ** 2`` rows.
        patch_size: patch side length (also the conv stride).
        emb_dim: token width.
    """

    def __init__(self, latent_size, patch_size, emb_dim, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        # Creation order and attribute names kept stable for saved weights.
        self.pos_emb = Embedding(
            input_dim=(latent_size // patch_size) ** 2,
            output_dim=emb_dim,
            embeddings_initializer="glorot_uniform",
        )
        self.reshaper = Dense(emb_dim, kernel_initializer="glorot_uniform")
        self.conv_expansion = Conv2D(emb_dim, kernel_size=patch_size, strides=patch_size, padding="same")

    def call(self, x):
        grid = self.conv_expansion(self.reshaper(x))
        dims = tf.shape(grid)
        # (batch, rows, cols, emb) -> (batch, rows * cols, emb)
        tokens = tf.reshape(grid, shape=[dims[0], dims[1] * dims[2], dims[3]])
        num_patches = (self.latent_size // self.patch_size) ** 2
        return tokens + self.pos_emb(tf.range(num_patches))

    def get_config(self):
        return {
            **super().get_config(),
            "latent_size": self.latent_size,
            "patch_size": self.patch_size,
            "emb_dim": self.emb_dim,
        }
|
|
|
|
|
@register_keras_serializable()
class ImageUnembedder(tf.keras.Layer):
    """Map a token sequence back to a latent image.

    Tokens are condition-normalized (AdaptiveLayerNorm). Each token is then
    projected to ``patch_size * patch_size * latent_dim`` values, and the
    tokens are laid out on a ``(latent_size // patch_size)``-sided grid.
    ``depth_to_space`` unfolds each cell into a ``patch_size x patch_size``
    patch.

    Args:
        latent_size: side length of the output latent.
        patch_size: patch side length.
        latent_dim: channels of the output latent.
    """

    def __init__(self, latent_size, patch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        # Attribute names kept stable for saved weights.
        self.AdaLN = AdaptiveLayerNorm()
        self.reshape_to_latent = Dense(patch_size * patch_size * latent_dim, kernel_initializer="glorot_uniform")

    def call(self, x, cond):
        tokens = self.reshape_to_latent(self.AdaLN(x, cond))
        side = self.latent_size // self.patch_size
        depth = self.latent_dim * self.patch_size ** 2
        grid = tf.reshape(tokens, shape=[tf.shape(tokens)[0], side, side, depth])
        return tf.nn.depth_to_space(grid, block_size=self.patch_size)

    def get_config(self):
        return {
            **super().get_config(),
            "latent_size": self.latent_size,
            "patch_size": self.patch_size,
            "latent_dim": self.latent_dim,
        }
|
|
|
|
|
|
|
|
@register_keras_serializable()
class ConditioningEmbedder(tf.keras.layers.Layer):
    """Fuse a diffusion timestep and a tokenized prompt into one conditioning vector.

    Steps:
    1. The timestep is looked up in a fixed sinusoidal table with T rows.
    2. Prompt tokens are embedded with a learned matrix, prefixed with a
       learned CLS vector, given learned positional encodings and passed
       through one TransformerBlock (6 heads).
    3. The timestep embedding is added to every position and the CLS position
       (index 0) is returned, giving a (batch, emb_dim) tensor.

    Args:
        emb_dim: embedding width. The sinusoid-table reshape requires it to be even.
        T: number of diffusion timesteps (rows of the sinusoid table).
        context_size: number of prompt tokens per example.
        vocab_size: rows of the learned token-embedding matrix. Token ids must
            be smaller than this.
    """

    def __init__(self, emb_dim, T, context_size, vocab_size=100266, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.T = T
        self.context_size = context_size
        self.vocab_size = vocab_size
        # Fixed (non-trainable) sinusoidal timestep table of shape (T, emb_dim).
        positions = tf.range(T, dtype=tf.float32)[:, tf.newaxis]
        # Frequencies 10000^(-2i/emb_dim) for i = 0 .. emb_dim/2 - 1.
        frequencies = tf.constant(10000 ** (-tf.range(0, emb_dim, 2, dtype=tf.float32) / emb_dim))
        angle_rates = positions * frequencies
        sin_part = tf.sin(angle_rates)
        cos_part = tf.cos(angle_rates)
        # Stack on a new last axis and flatten, so sin/cos are interleaved:
        # [sin_0, cos_0, sin_1, cos_1, ...].
        emb = tf.stack([sin_part, cos_part], axis=-1)
        emb = tf.reshape(emb, [T, emb_dim])
        self.t_embeddings = tf.constant(emb, dtype=tf.float32)

        # NOTE(review): keep the order of these add_weight calls stable.
        # A layer's own variables are presumably saved by creation index —
        # confirm against the Keras version in use before reordering.
        self.prompt_emb = self.add_weight(shape=(vocab_size, emb_dim), initializer='glorot_uniform', name='prompt_emb', trainable=True)
        self.CLS = self.add_weight(shape=(emb_dim,), initializer='glorot_uniform', name='CLS', trainable=True)
        # One extra position for the prepended CLS token.
        self.prompt_pos_enc = self.add_weight(shape=(1, context_size+1, emb_dim), initializer='glorot_uniform', name='prompt_pos_enc', trainable=True)
        # Head count is fixed at 6 here, not taken from the module-level ``heads`` config.
        self.transformer = TransformerBlock(context_size+1, head_no=6, latent_dim=emb_dim)

    def call(self, x):
        """Return the conditioning vector for ``x = (t, prompt_tokens)``.

        ``t`` has a trailing axis of size 1 that is squeezed away. Shapes are
        presumably ``t: (batch, 1)`` and ``prompt_tokens: (batch, context_size)``
        — TODO confirm against the caller.
        """
        t, prompt_tokens = x

        # Timestep ids -> sinusoid rows, with a sequence axis added for broadcasting.
        t = tf.cast(tf.squeeze(t, axis=-1), tf.int32)
        embedded_t = tf.gather(self.t_embeddings, t)
        embedded_t = embedded_t[:, tf.newaxis, :]

        # Token ids -> learned embeddings.
        embedded_prompt = tf.nn.embedding_lookup(
            self.prompt_emb, prompt_tokens)

        # Prepend one CLS vector per example (at position 0).
        cls_tok = tf.tile(self.CLS[None, None, :],
            [tf.shape(embedded_prompt)[0], 1, 1])
        embedded_prompt = tf.concat([cls_tok, embedded_prompt], axis=1)
        embedded_prompt += self.prompt_pos_enc
        embedded_prompt = self.transformer(embedded_prompt)

        # Add the timestep embedding to every position. Only CLS is returned,
        # so this is effectively CLS + t-embedding.
        embedded_prompt += embedded_t

        return embedded_prompt[:, 0, :]

    def get_config(self):
        config = super().get_config()
        config.update({
            "emb_dim": self.emb_dim,
            "T": self.T,
            "context_size": self.context_size,
            "vocab_size": self.vocab_size})
        return config
|
|
|
|
|
|
|
|
@register_keras_serializable()
class Gain(tf.keras.layers.Layer):
    """Per-feature multiplicative gate predicted from a conditioning vector.

    Computes ``x * expand_dims(M(cond), 1)``. ``M`` is a Dense layer whose
    width is taken from axis 2 of ``x`` at build time.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer kwargs (name, dtype, trainable, ...). The old
        # zero-argument __init__ rejected them, which also broke
        # from_config(get_config()).
        super(Gain, self).__init__(**kwargs)

    def build(self, input_shape):
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform')

    def call(self, x, cond):
        scale = self.M(cond)
        # Broadcast the per-example gate over axis 1 of x.
        return x * tf.expand_dims(scale, 1)
|
|
|
|
|
@register_keras_serializable()
class DiTBlock(tf.keras.layers.Layer):
    """Transformer block with condition-modulated norms and gated residuals.

    Both sub-layers (self-attention, then a 4x-wide GELU MLP) read an
    AdaptiveLayerNorm of the residual stream. Each output is scaled by a
    condition-dependent Gain before being added back to the stream.

    Args:
        hidden_dim: token width (stored as ``emb_dim``).
        heads: number of attention heads.
        context_size: only stored so it round-trips through ``get_config``.
    """

    def __init__(self, hidden_dim, heads, context_size, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = hidden_dim
        self.heads = heads
        self.context_size = context_size
        # Sub-layer attribute names and creation order kept stable for saved weights.
        self.gain1 = Gain()
        self.gain2 = Gain()
        self.adaLN1 = AdaptiveLayerNorm()
        self.attn = MultiHeadAttention(
            num_heads=heads,
            key_dim=hidden_dim // heads,
            output_shape=hidden_dim,
        )
        self.adaLN2 = AdaptiveLayerNorm()
        self.mlp_up = Dense(hidden_dim * 4, activation="gelu")
        self.mlp_down = Dense(hidden_dim)

    def call(self, x, cond):
        h = self.adaLN1(x, cond)
        x = x + self.gain1(self.attn(h, h, h), cond)
        h = self.mlp_up(self.adaLN2(x, cond))
        return x + self.gain2(self.mlp_down(h), cond)

    def get_config(self):
        return {
            **super().get_config(),
            "hidden_dim": self.emb_dim,
            "heads": self.heads,
            "context_size": self.context_size,
        }
|
|
|
|
|
# Load the three saved models once, at import time, in the order listed.
# The custom layer classes above are registered with
# @register_keras_serializable and must be defined before this point so
# load_model can resolve them. ``encoder`` is not used by the inference
# paths in this file.
encoder, decoder, diffuser = (
    tf.keras.models.load_model(path)
    for path in (
        "/repository/encoder.keras",
        "/repository/decoder.keras",
        "/repository/diffusion-med-coco.keras",
    )
)
|
|
|
|
|
def inference(prompts):
    """Generate one image per prompt with the full reverse diffusion loop.

    Starts from Gaussian noise of shape ``(N, *latent_shape)``. For
    t = T-1 ... 0 it applies ``denoise_step`` using the ``diffuser``
    prediction, then decodes the final latent with ``decoder``.

    Args:
        prompts: one prompt string or a sequence of prompt strings.

    Returns:
        The decoder output for the whole batch (one image per prompt).

    Raises:
        ValueError: if no prompts are given.
    """
    if isinstance(prompts, str):
        # A bare string is one prompt, not one prompt per character.
        prompts = [prompts]
    n = len(prompts)
    if n == 0:
        raise ValueError("inference() needs at least one prompt")

    # Use the configured latent shape rather than a hard-coded (32, 32, 4).
    x_t = tf.random.normal(shape=(n, *latent_shape))
    texts = tf.convert_to_tensor([process_text(p) for p in prompts])

    for t in reversed(range(T)):
        t_batch = tf.fill((n, 1), t)
        eps_hat = diffuser([x_t, texts, t_batch])
        # Stay in TF tensors: denoise_step accepts them directly, and the old
        # .numpy() conversions copied the latent to host and back every step.
        x_t = tf.convert_to_tensor(denoise_step(x_t, eps_hat, t, a, beta), dtype=tf.float32)

    return decoder(x_t)
|
|
|
|
|
class EndpointHandler:
    """Request handler for text-to-image generation.

    Presumably follows the Hugging Face Inference Endpoints custom-handler
    convention. The tokenizer and models are loaded at module import, so
    ``path`` is accepted for interface compatibility but not used.
    """

    def __init__(self, path="."):
        pass

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one image per prompt.

        Args:
            inputs: request payload. ``inputs["inputs"]`` is either a single
                prompt string or a list of prompt strings.

        Returns:
            ``{"outputs": images}``. ``images`` is the raw decoder output
            (not rescaled) converted to nested Python lists, one entry per
            prompt. An empty prompt list yields ``{"outputs": []}``.
        """
        prompts = inputs["inputs"]
        if isinstance(prompts, str):
            # Without this, len()/iteration would treat each character as a prompt.
            prompts = [prompts]
        n = len(prompts)
        if n == 0:
            return {"outputs": []}

        x_t = tf.random.normal(shape=(n, *latent_shape))
        texts = tf.convert_to_tensor([process_text(p) for p in prompts])

        for t in reversed(range(T)):
            t_batch = tf.fill((n, 1), t)
            eps_hat = diffuser([x_t, texts, t_batch])
            # Keep the latent on-device: denoise_step accepts tensors, so the
            # per-step .numpy() round trips were pure overhead.
            x_t = tf.convert_to_tensor(
                denoise_step(x_t, eps_hat, t, a, beta), dtype=tf.float32
            )

        imgs = decoder(x_t)
        return {"outputs": imgs.numpy().tolist()}