# Tokenizer is loaded at import time from a repository-local BPE vocabulary.
from transformers import PreTrainedTokenizerFast
token2vec = PreTrainedTokenizerFast.from_pretrained("/repository/bpe")
from typing import Dict, Any
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, Conv2D, UpSampling2D, Embedding, MultiHeadAttention
from tensorflow.keras.saving import register_keras_serializable
import tensorflow as tf  # NOTE(review): duplicate of the import above — harmless but redundant


# @title Config
def small_config():
    """Hyperparameters for the small model, plus the DDPM noise schedule.

    `beta` is a linear schedule over T steps; `alpha` = 1 - beta, and `a`
    is the cumulative product of alpha (alpha-bar in DDPM notation).
    """
    T = 500
    beta = np.linspace(1e-4, 0.02, T)
    alpha = 1 - beta
    a = np.cumprod(alpha)
    return {
        "filters": [128, 256],
        "hidden_dim": 384,
        "heads": 6,
        "layers": 8,
        "patch_size": 4,
        "batch_size": 64,
        "T": T,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": beta,
        "alpha": alpha,
        "a": a}


def med_config():
    """Hyperparameters for the medium model (the deployed checkpoint); see small_config."""
    T = 1000
    beta = np.linspace(1e-4, 0.02, T)
    alpha = 1 - beta
    a = np.cumprod(alpha)
    return {
        "filters": [128, 256],
        "hidden_dim": 768,
        "heads": 12,
        "layers": 12,
        "patch_size": 4,
        "batch_size": 64,
        "T": T,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": beta,
        "alpha": alpha,
        "a": a}


def large_config():
    """Hyperparameters for the large model; see small_config for field meanings."""
    T = 1000
    beta = np.linspace(1e-4, 0.02, T)
    alpha = 1 - beta
    a = np.cumprod(alpha)
    return {
        "filters": [128, 256],
        "hidden_dim": 1024,
        "heads": 16,
        "layers": 24,
        "patch_size": 4,
        "batch_size": 64,
        "T": T,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": beta,
        "alpha": alpha,
        "a": a}


# The served checkpoint is the medium configuration; unpack its values into
# module-level globals used throughout the rest of the file.
config = med_config()
filters = config['filters']
hidden_dim = config['hidden_dim']
heads = config['heads']
layers = config['layers']
patch_size = config['patch_size']
batch_size = config['batch_size']
T = config['T']
context_size = config['context_size']
image_size = config['image_size']
latent_shape = config['latent_shape']
beta = config['beta']
alpha = config['alpha']
a = config['a']


# @title ResBlock, UpBlock, DownBlock
@register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Residual conv block: a 1x1 channel projection followed by a
    two-convolution residual branch added back onto the projection."""

    def __init__(self, filters, p, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.filters = filters  # output channel count
        self.p = p              # kernel size of the residual convolutions
        # 1x1 conv so the skip path already has `filters` channels.
        self.reshape = Conv2D(filters, kernel_size=1, strides=1, padding="same")
        #self.norm = BatchNormalization(center=False, scale=False)
        self.conv1 = Conv2D(filters, kernel_size=p, strides=1, padding="same", activation="swish")
        self.conv2 = Conv2D(filters, kernel_size=p, strides=1, padding="same")

    def call(self, x):
        x = self.reshape(x)
        resid = x
        #resid = self.norm(resid)
        resid = self.conv1(resid)
        resid = self.conv2(resid)
        x = x + resid
        return x

    def get_config(self):
        # Serialize constructor args so the layer round-trips through .keras files.
        config = super().get_config()
        config.update({"filters": self.filters, "p": self.p})
        return config


@register_keras_serializable()
class DownBlock(tf.keras.layers.Layer):
    """A stack of ResBlocks (one per entry of `filters`) followed by 2x max-pooling."""

    def __init__(self, filters, **kwargs):
        super(DownBlock, self).__init__(**kwargs)
        self.filters = filters  # list of channel counts, one ResBlock each
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    def call(self, x):
        for resBlock in self.resBlocks:
            x = resBlock(x)
        x = self.pool(x)  # halves the spatial resolution
        return x

    def get_config(self):
        config = super().get_config()
        config.update({"filters": self.filters})
        return config


@register_keras_serializable()
class UpBlock(tf.keras.layers.Layer):
    """2x bilinear upsampling followed by a stack of ResBlocks (mirror of DownBlock)."""

    def __init__(self, filters, **kwargs):
        super(UpBlock, self).__init__(**kwargs)
        self.filters = filters  # list of channel counts, one ResBlock each
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.upSample = UpSampling2D(size=2, interpolation="bilinear")

    def call(self, x):
        x = self.upSample(x)  # doubles the spatial resolution
        for resBlock in self.resBlocks:
            x = resBlock(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({"filters": self.filters})
        return config


# @title Encoder, Decoder
@register_keras_serializable()
class Encoder(tf.keras.Model):
    """VAE encoder: DownBlocks reduce resolution, then a 1x1 conv emits
    `latent_dim * 2` channels that are split into (mu, logvar) and
    reparameterized into a latent sample z."""

    def __init__(self, filters, latent_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.filters = filters
        self.latent_dim = latent_dim
        self.downBlocks = [DownBlock([f, f]) for f in filters]
        # Twice the latent channels: first half is mu, second half is logvar.
        self.latent_proj = Conv2D(latent_dim * 2, kernel_size=1, strides=1, padding="same", activation="linear")

    @tf.function
    def sample(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, sigma = exp(logvar / 2).
        eps = tf.random.normal(shape=tf.shape(mu))
        return eps * tf.exp(logvar * .5) + mu

    def call(self, x, training=1):
        for downBlock in self.downBlocks:
            x = downBlock(x)
        x = self.latent_proj(x)
        mu, logvar = tf.split(x, 2, axis=-1)
        z = self.sample(mu, logvar)
        return z, mu, logvar

    def get_config(self):
        config = super().get_config()
        config.update({"filters": self.filters, "latent_dim": self.latent_dim})
        return config

    def compute_output_shape(self, input_shape):
        # NOTE(review): call() actually returns spatial maps (batch, H', W',
        # latent_dim) for all three outputs, not flat (batch, latent_dim)
        # vectors — worth confirming this declaration is never relied upon.
        return (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim)


@register_keras_serializable()
class Decoder(tf.keras.Model):
    """VAE decoder: a 1x1 conv lifts the latent back to feature channels,
    UpBlocks restore resolution, and a final 3-channel conv emits the image."""

    def __init__(self, filters, img_size, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        # NOTE(review): self.filters is stored reversed, but upBlocks below are
        # built from the *unreversed* `filters` argument — channel counts grow
        # while upsampling. This is what the trained checkpoint uses, so it
        # must not be "fixed" without retraining; flagging for awareness only.
        self.filters = filters[::-1]
        self.img_size = img_size
        self.undo_latent_proj = Conv2D(filters[0], kernel_size=1, strides=1, padding="same")
        self.upBlocks = [UpBlock([f, f]) for f in filters]
        self.conv_proj = Conv2D(3, kernel_size=3, padding="same", activation="linear")

    def call(self, z, training=1):
        z = self.undo_latent_proj(z)
        for upBlock in self.upBlocks:
            z = upBlock(z)
        x = self.conv_proj(z)
        return x

    def get_config(self):
        # Re-reverse so the serialized value equals the original constructor arg.
        config = super().get_config()
        config.update({"filters": self.filters[::-1], "img_size": self.img_size})
        return config

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.img_size, self.img_size, 3)


# @title Helper Functions
def process_text(text):
    """Tokenize `text` with the BPE tokenizer, zero-pad to `context_size`,
    and truncate to exactly `context_size` token ids (returned as an ndarray)."""
    tokens = token2vec.encode(text)
    while len(tokens) < context_size:
        tokens.append(0)
    return np.array(tokens[0:context_size])


def normalise_img(img_tensor):
    # Maps [-1,1] to [0,1].
    # NOTE(review): `img` aliases the argument, so the in-place *= / += mutate
    # the caller's array — callers must not reuse the input afterwards.
    img = img_tensor
    img *= 0.5
    img += 0.5
    return img


def prep_img(img_tensor):
    # Maps [0,255] to [-1,1] (copies first, so the input is left untouched).
    img = img_tensor.copy()
    img = img / 127.5
    img -= 1
    return img


def noisify_img(img_tensor, t, a):
    # Forward diffusion q(x_t | x_0): returns x_t and the noise used.
    # `a` is the cumulative alpha product (alpha-bar) indexed by timestep t.
    epsilon = np.random.normal(0, 1, img_tensor.shape).astype(np.float32)  # Standard normal
    sqrt_alpha_bar = np.sqrt(a[t])
    sqrt_one_minus_alpha_bar = np.sqrt(1 - a[t])
    x_t = sqrt_alpha_bar * img_tensor + sqrt_one_minus_alpha_bar * epsilon
    return x_t, epsilon


def denoise_step(x_t, eps_hat, t, a, beta):
    """ Reverse one DDPM step: x_t → x_{t-1}

    Args:
        x_t: current noisy latent.
        eps_hat: the model's noise prediction for timestep t.
        t: integer timestep index.
        a: cumulative alpha product (alpha-bar) array.
        beta: per-step variance schedule array.
    """
    a_bar_t = tf.convert_to_tensor(a[t], dtype=tf.float32)
    a_bar_prev = tf.convert_to_tensor(a[t - 1] if t > 0 else 1.0, dtype=tf.float32)
    a_t = a_bar_t / a_bar_prev  # per-step alpha recovered from the cumulative products
    beta_t = tf.convert_to_tensor(beta[t], dtype=tf.float32)
    # Avoid NaNs with clamping
    sqrt_recip_a_t = tf.math.rsqrt(tf.maximum(a_t, 1e-5))
    sqrt_one_minus_ab = tf.sqrt(tf.maximum(1. - a_bar_t, 1e-5))
    eps_term = (beta_t / sqrt_one_minus_ab) * eps_hat
    mean = sqrt_recip_a_t * (x_t - eps_term)
    # Fresh noise (sigma_t = sqrt(beta_t)) is only added for t > 1; the final
    # steps are deterministic.
    if t > 1:
        noise = tf.random.normal(shape=x_t.shape)
        sigma = tf.sqrt(tf.maximum(beta_t, 1e-5))
        x_prev = mean + sigma * noise
    else:
        x_prev = mean
    return x_prev


# @title Transformer Block
@register_keras_serializable()
class TransformerBlock(tf.keras.Layer):
    """Pre-norm transformer block: LN → self-attention residual, then
    LN → 4x-expansion MLP residual."""

    def __init__(self, context_size, head_no, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.context_size = context_size
        self.head_no = head_no
        self.latent_dim = latent_dim
        self.attn = MultiHeadAttention(num_heads=head_no, key_dim=latent_dim//head_no, output_shape=latent_dim)
        self.mlp_up = Dense(latent_dim*4, activation="gelu")
        self.mlp_down = Dense(latent_dim)
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()

    def call(self, x):
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)
        normed = self.norm2(x)
        dx = self.mlp_up(normed)
        x = x + self.mlp_down(dx)
        return x

    def build(self, input_shape):
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        # Shape-preserving: output matches the input token grid.
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update({"context_size": self.context_size, "head_no": self.head_no, "latent_dim": self.latent_dim})
        return config


# @title AdaLN-Zero
@register_keras_serializable()
class AdaptiveLayerNorm(tf.keras.Layer):
    """Adaptive layer norm (DiT-style): a plain LN (no learned affine) whose
    scale and shift are regressed from a conditioning vector per example."""

    def __init__(self, eps=1e-6, **kwargs):
        # NOTE(review): an attribute is assigned before super().__init__() —
        # Keras conventionally requires super().__init__() first; confirm this
        # loads/trains cleanly under the Keras version in use.
        self.layernorm = LayerNormalization(epsilon=eps, center=False, scale=False)
        super(AdaptiveLayerNorm, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is (B, num_patches, hidden_dim); gamma/beta are sized to hidden_dim.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")
        self.b = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")

    def call(self, x, cond):
        gamma = self.M(cond)
        beta = self.b(cond)
        x = self.layernorm(x)
        # Broadcast the per-example (gamma, beta) over the patch axis.
        x = x * (1 + tf.expand_dims(gamma, 1)) + tf.expand_dims(beta, 1)
        return x

    def get_config(self):
        config = super().get_config()
        return config


# @title Image Embedder, Unembedder
@register_keras_serializable()
class ImageEmbedder(tf.keras.Layer):
    """Patchify a latent map into a token sequence: per-pixel Dense, strided
    conv over patch_size x patch_size patches, flatten, add learned positions."""

    def __init__(self, latent_size, patch_size, emb_dim, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        # One learned position embedding per patch token.
        self.pos_emb = Embedding(input_dim=(latent_size // patch_size)**2, output_dim=emb_dim, embeddings_initializer="glorot_uniform")
        self.reshaper = Dense(emb_dim, kernel_initializer="glorot_uniform")
        # stride == kernel_size == patch_size: non-overlapping patch projection.
        self.conv_expansion = Conv2D(emb_dim, kernel_size=patch_size, strides=patch_size, padding="same")

    def call(self, x):
        x = self.reshaper(x)
        x = self.conv_expansion(x)
        # (B, H', W', emb_dim) -> (B, H'*W', emb_dim) token sequence.
        x = tf.reshape(x, shape=[tf.shape(x)[0], tf.shape(x)[1]*tf.shape(x)[2], tf.shape(x)[3]])
        positions = tf.range(start=0, limit=(self.latent_size // self.patch_size)**2, delta=1)
        embeddings = self.pos_emb(positions)
        x = embeddings + x
        return x

    def get_config(self):
        config = super().get_config()
        config.update({"latent_size": self.latent_size, "patch_size": self.patch_size, "emb_dim": self.emb_dim})
        return config


@register_keras_serializable()
class ImageUnembedder(tf.keras.Layer):
    """Inverse of ImageEmbedder: conditioned LN, project tokens back to
    patch pixels, and depth_to_space to reassemble the latent map."""

    def __init__(self, latent_size, patch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        self.AdaLN = AdaptiveLayerNorm()
        self.reshape_to_latent = Dense(patch_size*patch_size*latent_dim, kernel_initializer="glorot_uniform")

    def call(self, x, cond):
        x = self.AdaLN(x, cond)
        x = self.reshape_to_latent(x)
        # Tokens -> patch grid with patch pixels packed into channels...
        x = tf.reshape(x, shape=[tf.shape(x)[0], self.latent_size // self.patch_size, self.latent_size // self.patch_size, self.latent_dim*(self.patch_size**2)])
        # ...then unpack channels into the spatial dims: (B, latent_size, latent_size, latent_dim).
        x = tf.nn.depth_to_space(x, block_size=self.patch_size)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({"latent_size": self.latent_size, "patch_size": self.patch_size, "latent_dim": self.latent_dim})
        return config


# @title LEGACY Prompt and Timestep Embedder
@register_keras_serializable()
class ConditioningEmbedder(tf.keras.layers.Layer):
    """Joint timestep + prompt conditioning: fixed sinusoidal timestep table,
    learned token embeddings with a CLS token and positions, one transformer
    block, then the CLS output (plus the timestep embedding) is returned."""

    def __init__(self, emb_dim, T, context_size, vocab_size=100266, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.T = T
        self.context_size = context_size
        self.vocab_size = vocab_size
        # Precompute the sinusoidal timestep table as a constant (not trainable).
        positions = tf.range(T, dtype=tf.float32)[:, tf.newaxis]
        frequencies = tf.constant(10000 ** (-tf.range(0, emb_dim, 2, dtype=tf.float32) / emb_dim))
        angle_rates = positions * frequencies  # (T, emb_dim/2)
        sin_part = tf.sin(angle_rates)
        cos_part = tf.cos(angle_rates)
        emb = tf.stack([sin_part, cos_part], axis=-1)  # (T, emb_dim/2, 2)
        emb = tf.reshape(emb, [T, emb_dim])  # (T, emb_dim)
        self.t_embeddings = tf.constant(emb, dtype=tf.float32)
        self.prompt_emb = self.add_weight(shape=(vocab_size, emb_dim), initializer='glorot_uniform', name='prompt_emb', trainable=True)
        self.CLS = self.add_weight(shape=(emb_dim,), initializer='glorot_uniform', name='CLS', trainable=True)
        self.prompt_pos_enc = self.add_weight(shape=(1, context_size+1, emb_dim), initializer='glorot_uniform', name='prompt_pos_enc', trainable=True)
        self.transformer = TransformerBlock(context_size+1, head_no=6, latent_dim=emb_dim)

    def call(self, x):
        t, prompt_tokens = x
        # ── timestep embedding ───────────────────────────
        t = tf.cast(tf.squeeze(t, axis=-1), tf.int32)  # (batch,)
        embedded_t = tf.gather(self.t_embeddings, t)  # (batch, emb_dim)
        embedded_t = embedded_t[:, tf.newaxis, :]  # (batch, 1, emb_dim)
        # ── prompt embedding path ─────────────────────────
        embedded_prompt = tf.nn.embedding_lookup(self.prompt_emb, prompt_tokens)  # (batch, seq_len, emb_dim)
        cls_tok = tf.tile(self.CLS[None, None, :], [tf.shape(embedded_prompt)[0], 1, 1])
        embedded_prompt = tf.concat([cls_tok, embedded_prompt], axis=1)
        embedded_prompt += self.prompt_pos_enc
        embedded_prompt = self.transformer(embedded_prompt)  # (batch, seq_len+1, emb_dim)
        # add t-embedding to every token (broadcasts along axis-1)
        embedded_prompt += embedded_t
        # Return only the CLS token's representation.
        return embedded_prompt[:, 0, :]  # (batch, emb_dim)

    def get_config(self):
        config = super().get_config()
        config.update({"emb_dim": self.emb_dim, "T": self.T, "context_size": self.context_size, "vocab_size": self.vocab_size})
        return config


# @title DiT Block
class Gain(tf.keras.layers.Layer):
    """Conditioned per-channel gain: regresses a scale vector from `cond` and
    multiplies it (broadcast over the patch axis) onto x, in place.

    NOTE(review): no @register_keras_serializable / get_config here, unlike the
    sibling layers — confirm it still deserializes as a tracked sublayer of
    DiTBlock under the Keras version in use.
    """

    def __init__(self):
        super(Gain, self).__init__()

    def build(self, input_shape):
        # input_shape is (B, num_patches, hidden_dim); scale is sized to hidden_dim.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform')

    def call(self, x, cond):
        scale = self.M(cond)
        x *= tf.expand_dims(scale, 1)
        return x


@register_keras_serializable()
class DiTBlock(tf.keras.layers.Layer):
    """DiT transformer block with adaptive layer norm conditioning:
    AdaLN → attention → conditioned gain residual, then
    AdaLN → MLP → conditioned gain residual."""

    def __init__(self, hidden_dim, heads, context_size, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = hidden_dim
        self.heads = heads
        self.context_size = context_size
        self.gain1 = Gain()
        self.gain2 = Gain()
        self.adaLN1 = AdaptiveLayerNorm()
        self.attn = MultiHeadAttention(num_heads=self.heads, key_dim=self.emb_dim//self.heads, output_shape=self.emb_dim)
        self.adaLN2 = AdaptiveLayerNorm()
        self.mlp_up = Dense(self.emb_dim*4, activation="gelu")
        self.mlp_down = Dense(self.emb_dim)

    def call(self, x, cond):
        R = self.adaLN1(x, cond)
        R = self.gain1(self.attn(R, R, R), cond)
        x = x + R
        R = self.adaLN2(x, cond)
        R = self.mlp_up(R)
        R = self.gain2(self.mlp_down(R), cond)
        x = x + R
        return x

    def get_config(self):
        # Serialized under the constructor-arg name "hidden_dim" (stored as emb_dim).
        config = super().get_config()
        config.update({"hidden_dim": self.emb_dim, "heads": self.heads, "context_size": self.context_size})
        return config
# Pre-trained components are loaded once at import time so every request
# reuses them (the custom layers above must be registered before this runs).
encoder = tf.keras.models.load_model("/repository/encoder.keras")
decoder = tf.keras.models.load_model("/repository/decoder.keras")
diffuser = tf.keras.models.load_model("/repository/diffusion-med-coco.keras")


def inference(prompts):
    """Run the full reverse-diffusion sampling loop for a batch of prompts.

    Args:
        prompts: list of prompt strings.

    Returns:
        The decoder's output tensor for the denoised latents, shape
        (len(prompts), image_size, image_size, 3).
    """
    N = len(prompts)
    # Start from pure Gaussian noise in the VAE latent space (was hardcoded
    # (32, 32, 4); latent_shape is the same value, kept in one place).
    x_t = tf.random.normal(shape=(N, *latent_shape))
    texts = tf.convert_to_tensor([process_text(p) for p in prompts])
    # Walk t = T-1 ... 0: predict the noise, then step x_t -> x_{t-1}.
    for t in reversed(range(T)):
        t_batch = tf.convert_to_tensor([[t]] * N)
        eps_hat = diffuser([x_t, texts, t_batch])
        # denoise_step is built entirely from tf ops, so tensors pass through
        # directly — no per-step .numpy() round-trip needed.
        x_t = denoise_step(x_t, eps_hat, t, a, beta)
    return decoder(x_t)


class EndpointHandler:
    """Hugging Face inference-endpoint entry point."""

    def __init__(self, path="."):
        pass  # models already loaded at module import time

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate images for inputs["inputs"], a list of prompt strings.

        Returns:
            {"outputs": nested lists of image pixel values}.
        """
        prompts = inputs["inputs"]
        # Delegate to the single shared sampling loop instead of duplicating it.
        imgs = inference(prompts)
        return {"outputs": imgs.numpy().tolist()}