detectivejoewest commited on
Commit
4742deb
Β·
verified Β·
1 Parent(s): 08ea9dc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +529 -531
handler.py CHANGED
@@ -1,532 +1,530 @@
1
- from transformers import PreTrainedTokenizerFast
2
-
3
- token2vec = PreTrainedTokenizerFast.from_pretrained("cl100k_tokenizer")
4
-
5
- from typing import Dict, Any
6
- import numpy as np
7
- import tensorflow as tf
8
- from tensorflow.keras.layers import Dense, LayerNormalization, Conv2D, UpSampling2D, Embedding, MultiHeadAttention
9
- from tensorflow.keras.saving import register_keras_serializable
10
- import tensorflow as tf
11
-
12
- token2vec = tiktoken.encoding_for_model("gpt-3.5-turbo")
13
-
14
- # @title Config
15
- def small_config():
16
- T = 500
17
- beta = np.linspace(1e-4, 0.02, T)
18
- alpha = 1 - beta
19
- a = np.cumprod(alpha)
20
-
21
- return {
22
- "filters": [128, 256],
23
- "hidden_dim": 384,
24
- "heads": 6,
25
- "layers": 8,
26
- "patch_size": 4,
27
- "batch_size": 64,
28
- "T": T,
29
- "context_size": 8,
30
- "image_size": 128,
31
- "latent_shape": (32, 32, 4),
32
- "beta": beta,
33
- "alpha": alpha,
34
- "a": a}
35
-
36
- def med_config():
37
- T = 1000
38
- beta = np.linspace(1e-4, 0.02, T)
39
- alpha = 1 - beta
40
- a = np.cumprod(alpha)
41
-
42
- return {
43
- "filters": [128, 256],
44
- "hidden_dim": 768,
45
- "heads": 12,
46
- "layers": 12,
47
- "patch_size": 4,
48
- "batch_size": 64,
49
- "T": T,
50
- "context_size": 8,
51
- "image_size": 128,
52
- "latent_shape": (32, 32, 4),
53
- "beta": beta,
54
- "alpha": alpha,
55
- "a": a}
56
-
57
- def large_config():
58
- T = 1000
59
- beta = np.linspace(1e-4, 0.02, T)
60
- alpha = 1 - beta
61
- a = np.cumprod(alpha)
62
-
63
- return {
64
- "filters": [128, 256],
65
- "hidden_dim": 1024,
66
- "heads": 16,
67
- "layers": 24,
68
- "patch_size": 4,
69
- "batch_size": 64,
70
- "T": T,
71
- "context_size": 8,
72
- "image_size": 128,
73
- "latent_shape": (32, 32, 4),
74
- "beta": beta,
75
- "alpha": alpha,
76
- "a": a}
77
-
78
- config = med_config()
79
-
80
- filters = config['filters']
81
- hidden_dim = config['hidden_dim']
82
- heads = config['heads']
83
- layers = config['layers']
84
- patch_size = config['patch_size']
85
- batch_size = config['batch_size']
86
- T = config['T']
87
- context_size = config['context_size']
88
- image_size = config['image_size']
89
- latent_shape = config['latent_shape']
90
- beta = config['beta']
91
- alpha = config['alpha']
92
- a = config['a']
93
-
94
- # @title ResBlock, UpBlock, DownBlock
95
- @register_keras_serializable()
96
- class ResBlock(tf.keras.layers.Layer):
97
- def __init__(self, filters, p, **kwargs):
98
- super(ResBlock, self).__init__(**kwargs)
99
- self.filters = filters
100
- self.p = p
101
- self.reshape = Conv2D(filters, kernel_size=1, strides=1, padding="same")
102
- #self.norm = BatchNormalization(center=False, scale=False)
103
- self.conv1 = Conv2D(filters, kernel_size=p, strides=1, padding="same", activation="swish")
104
- self.conv2 = Conv2D(filters, kernel_size=p, strides=1, padding="same")
105
-
106
- def call(self, x):
107
- x = self.reshape(x)
108
- resid = x
109
- #resid = self.norm(resid)
110
- resid = self.conv1(resid)
111
- resid = self.conv2(resid)
112
- x = x + resid
113
- return x
114
-
115
- def get_config(self):
116
- config = super().get_config()
117
- config.update({
118
- "filters": self.filters,
119
- "p": self.p})
120
- return config
121
-
122
- @register_keras_serializable()
123
- class DownBlock(tf.keras.layers.Layer):
124
- def __init__(self, filters, **kwargs):
125
- super(DownBlock, self).__init__(**kwargs)
126
- self.filters = filters
127
- self.resBlocks = [ResBlock(f, p=3) for f in filters]
128
- self.pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))
129
-
130
- def call(self, x):
131
- for resBlock in self.resBlocks:
132
- x = resBlock(x)
133
- x = self.pool(x)
134
- return x
135
-
136
- def get_config(self):
137
- config = super().get_config()
138
- config.update({
139
- "filters": self.filters})
140
- return config
141
-
142
- @register_keras_serializable()
143
- class UpBlock(tf.keras.layers.Layer):
144
- def __init__(self, filters, **kwargs):
145
- super(UpBlock, self).__init__(**kwargs)
146
- self.filters = filters
147
- self.resBlocks = [ResBlock(f, p=3) for f in filters]
148
- self.upSample = UpSampling2D(size=2, interpolation="bilinear")
149
-
150
- def call(self, x):
151
- x = self.upSample(x)
152
- for resBlock in self.resBlocks:
153
- x = resBlock(x)
154
- return x
155
-
156
- def get_config(self):
157
- config = super().get_config()
158
- config.update({
159
- "filters": self.filters})
160
- return config
161
-
162
- # @title Encoder, Decoder
163
- @register_keras_serializable()
164
- class Encoder(tf.keras.Model):
165
- def __init__(self, filters, latent_dim, **kwargs):
166
- super(Encoder, self).__init__(**kwargs)
167
- self.filters = filters
168
- self.latent_dim = latent_dim
169
- self.downBlocks = [DownBlock([f,f]) for f in filters]
170
- self.latent_proj = Conv2D(latent_dim * 2, kernel_size=1, strides=1, padding="same", activation="linear")
171
-
172
- @tf.function
173
- def sample(self, mu, logvar):
174
- eps = tf.random.normal(shape=tf.shape(mu))
175
- return eps * tf.exp(logvar * .5) + mu
176
-
177
- def call(self, x, training=1):
178
- for downBlock in self.downBlocks:
179
- x = downBlock(x)
180
- x = self.latent_proj(x)
181
- mu, logvar = tf.split(x, 2, axis=-1)
182
- z = self.sample(mu, logvar)
183
- return z, mu, logvar
184
-
185
- def get_config(self):
186
- config = super().get_config()
187
- config.update({
188
- "filters": self.filters,
189
- "latent_dim": self.latent_dim})
190
- return config
191
-
192
- def compute_output_shape(self, input_shape):
193
- return (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim)
194
-
195
- @register_keras_serializable()
196
- class Decoder(tf.keras.Model):
197
- def __init__(self, filters, img_size, **kwargs):
198
- super(Decoder, self).__init__(**kwargs)
199
- self.filters = filters[::-1]
200
- self.img_size = img_size
201
- self.undo_latent_proj = Conv2D(filters[0], kernel_size=1, strides=1, padding="same")
202
- self.upBlocks = [UpBlock([f,f]) for f in filters]
203
- self.conv_proj = Conv2D(3, kernel_size=3, padding="same", activation="linear")
204
-
205
- def call(self, z, training=1):
206
- z = self.undo_latent_proj(z)
207
- for upBlock in self.upBlocks:
208
- z = upBlock(z)
209
- x = self.conv_proj(z)
210
- return x
211
-
212
- def get_config(self):
213
- config = super().get_config()
214
- config.update({
215
- "filters": self.filters[::-1],
216
- "img_size": self.img_size})
217
- return config
218
-
219
- def compute_output_shape(self, input_shape):
220
- return (input_shape[0], self.img_size, self.img_size, 3)
221
-
222
- # @title Helper Functions
223
- def process_text(text):
224
- import tiktoken
225
- tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
226
- tokens = tokenizer.encode(text)
227
- while len(tokens) < context_size:
228
- tokens.append(0)
229
- return tokens[:context_size]
230
-
231
- def normalise_img(img_tensor): # Maps [-1,1] to [0,1]
232
- img = img_tensor
233
- img *= 0.5
234
- img += 0.5
235
- return img
236
-
237
- def prep_img(img_tensor): # Maps [0,255] to [-1,1]
238
- img = img_tensor.copy()
239
- img = img / 127.5
240
- img -= 1
241
- return img
242
-
243
- def noisify_img(img_tensor, t, a): # Returns x_t and the noise used
244
- epsilon = np.random.normal(0, 1, img_tensor.shape).astype(np.float32) # Standard normal
245
- sqrt_alpha_bar = np.sqrt(a[t])
246
- sqrt_one_minus_alpha_bar = np.sqrt(1 - a[t])
247
- x_t = sqrt_alpha_bar * img_tensor + sqrt_one_minus_alpha_bar * epsilon
248
- return x_t, epsilon
249
-
250
- def denoise_step(x_t, eps_hat, t, a, beta):
251
- """
252
- Reverse one DDPM step: x_t β†’ x_{t-1}
253
- """
254
- a_bar_t = tf.convert_to_tensor(a[t], dtype=tf.float32)
255
- a_bar_prev = tf.convert_to_tensor(a[t - 1] if t > 0 else 1.0, dtype=tf.float32)
256
- a_t = a_bar_t / a_bar_prev
257
- beta_t = tf.convert_to_tensor(beta[t], dtype=tf.float32)
258
-
259
- # Avoid NaNs with clamping
260
- sqrt_recip_a_t = tf.math.rsqrt(tf.maximum(a_t, 1e-5))
261
- sqrt_one_minus_ab = tf.sqrt(tf.maximum(1. - a_bar_t, 1e-5))
262
-
263
- eps_term = (beta_t / sqrt_one_minus_ab) * eps_hat
264
- mean = sqrt_recip_a_t * (x_t - eps_term)
265
-
266
- if t > 1:
267
- noise = tf.random.normal(shape=x_t.shape)
268
- sigma = tf.sqrt(tf.maximum(beta_t, 1e-5))
269
- x_prev = mean + sigma * noise
270
- else:
271
- x_prev = mean
272
-
273
- return x_prev
274
-
275
- # @title Transformer Block
276
- @register_keras_serializable()
277
- class TransformerBlock(tf.keras.Layer):
278
- def __init__(self, context_size, head_no, latent_dim, **kwargs):
279
- super().__init__(**kwargs)
280
- self.context_size = context_size
281
- self.head_no = head_no
282
- self.latent_dim = latent_dim
283
- self.attn = MultiHeadAttention(num_heads=head_no, key_dim=latent_dim//head_no, output_shape=latent_dim)
284
- self.mlp_up = Dense(latent_dim*4, activation="gelu")
285
- self.mlp_down = Dense(latent_dim)
286
- self.norm1 = LayerNormalization()
287
- self.norm2 = LayerNormalization()
288
-
289
- def call(self, x):
290
- normed = self.norm1(x)
291
- x = x + self.attn(normed, normed, normed)
292
- normed = self.norm2(x)
293
- dx = self.mlp_up(normed)
294
- x = x + self.mlp_down(dx)
295
- return x
296
-
297
- def build(self, input_shape):
298
- super().build(input_shape)
299
-
300
- def compute_output_shape(self, input_shape):
301
- return input_shape
302
-
303
- def get_config(self):
304
- config = super().get_config()
305
- config.update({
306
- "context_size": self.context_size,
307
- "head_no": self.head_no,
308
- "latent_dim": self.latent_dim})
309
- return config
310
-
311
- # @title AdaLN-Zero
312
- @register_keras_serializable()
313
- class AdaptiveLayerNorm(tf.keras.Layer):
314
- def __init__(self, eps=1e-6,**kwargs):
315
- self.layernorm = LayerNormalization(epsilon=eps,center=False, scale=False)
316
- super(AdaptiveLayerNorm, self).__init__(**kwargs)
317
-
318
- def build(self, input_shape):
319
- #B, num_patches, hidden_dim
320
- self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")
321
- self.b = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")
322
-
323
- def call(self, x, cond):
324
- gamma = self.M(cond)
325
- beta = self.b(cond)
326
- x = self.layernorm(x)
327
- x = x * (1 + tf.expand_dims(gamma, 1)) + tf.expand_dims(beta, 1)
328
- return x
329
-
330
- def get_config(self):
331
- config = super().get_config()
332
- return config
333
-
334
- # @title Image Embedder, Unembedder
335
- @register_keras_serializable()
336
- class ImageEmbedder(tf.keras.Layer):
337
- def __init__(self, latent_size, patch_size, emb_dim,**kwargs):
338
- super().__init__(**kwargs)
339
- self.emb_dim = emb_dim
340
- self.patch_size = patch_size
341
- self.latent_size = latent_size
342
- self.pos_emb = Embedding(input_dim=(latent_size // patch_size)**2 , output_dim=emb_dim, embeddings_initializer="glorot_uniform")
343
- self.reshaper = Dense(emb_dim, kernel_initializer="glorot_uniform")
344
- self.conv_expansion = Conv2D(emb_dim, kernel_size=patch_size, strides=patch_size, padding="same")
345
-
346
- def call(self, x):
347
- x = self.reshaper(x)
348
- x = self.conv_expansion(x)
349
- x = tf.reshape(x, shape=[tf.shape(x)[0], tf.shape(x)[1]*tf.shape(x)[2], tf.shape(x)[3]])
350
- positions = tf.range(start=0, limit=(self.latent_size // self.patch_size)**2, delta=1)
351
- embeddings = self.pos_emb(positions)
352
- x = embeddings + x
353
- return x
354
-
355
- def get_config(self):
356
- config = super().get_config()
357
- config.update({
358
- "latent_size" : self.latent_size,
359
- "patch_size": self.patch_size,
360
- "emb_dim": self.emb_dim})
361
- return config
362
-
363
- @register_keras_serializable()
364
- class ImageUnembedder(tf.keras.Layer):
365
- def __init__(self, latent_size, patch_size, latent_dim, **kwargs):
366
- super().__init__(**kwargs)
367
- self.latent_dim = latent_dim
368
- self.patch_size = patch_size
369
- self.latent_size = latent_size
370
- self.AdaLN = AdaptiveLayerNorm()
371
- self.reshape_to_latent = Dense(patch_size*patch_size*latent_dim, kernel_initializer="glorot_uniform")
372
-
373
- def call(self, x, cond):
374
- x = self.AdaLN(x, cond)
375
- x = self.reshape_to_latent(x)
376
- x = tf.reshape(x, shape=
377
- [tf.shape(x)[0],
378
- self.latent_size // self.patch_size,
379
- self.latent_size // self.patch_size,
380
- self.latent_dim*(self.patch_size**2)])
381
- x = tf.nn.depth_to_space(x, block_size=self.patch_size)
382
- return x
383
-
384
- def get_config(self):
385
- config = super().get_config()
386
- config.update({
387
- "latent_size" : self.latent_size,
388
- "patch_size": self.patch_size,
389
- "latent_dim": self.latent_dim})
390
- return config
391
-
392
- # @title LEGACY Prompt and Timestep Embedder
393
- @register_keras_serializable()
394
- class ConditioningEmbedder(tf.keras.layers.Layer):
395
- def __init__(self, emb_dim, T, context_size, vocab_size=100266, **kwargs):
396
- super().__init__(**kwargs)
397
- self.emb_dim = emb_dim
398
- self.T = T
399
- self.context_size = context_size
400
- self.vocab_size = vocab_size
401
- positions = tf.range(T, dtype=tf.float32)[:, tf.newaxis]
402
- frequencies = tf.constant(10000 ** (-tf.range(0, emb_dim, 2, dtype=tf.float32) / emb_dim))
403
- angle_rates = positions * frequencies # (T, emb_dim/2)
404
- sin_part = tf.sin(angle_rates)
405
- cos_part = tf.cos(angle_rates)
406
- emb = tf.stack([sin_part, cos_part], axis=-1) # (T, emb_dim/2, 2)
407
- emb = tf.reshape(emb, [T, emb_dim]) # (T, emb_dim)
408
- self.t_embeddings = tf.constant(emb, dtype=tf.float32)
409
-
410
- self.prompt_emb = self.add_weight(shape=(vocab_size, emb_dim), initializer='glorot_uniform', name='prompt_emb', trainable=True)
411
- self.CLS = self.add_weight(shape=(emb_dim,), initializer='glorot_uniform', name='CLS', trainable=True)
412
- self.prompt_pos_enc = self.add_weight(shape=(1, context_size+1, emb_dim), initializer='glorot_uniform', name='prompt_pos_enc', trainable=True)
413
- self.transformer = TransformerBlock(context_size+1, head_no=6, latent_dim=emb_dim)
414
-
415
- def call(self, x):
416
- t, prompt_tokens = x
417
-
418
- # ── timestep embedding ───────────────────────────
419
- t = tf.cast(tf.squeeze(t, axis=-1), tf.int32) # (batch,)
420
- embedded_t = tf.gather(self.t_embeddings, t) # (batch, emb_dim)
421
- embedded_t = embedded_t[:, tf.newaxis, :] # (batch, 1, emb_dim)
422
-
423
- # ── prompt embedding path ─────────────────────────
424
- embedded_prompt = tf.nn.embedding_lookup(
425
- self.prompt_emb, prompt_tokens) # (batch, seq_len, emb_dim)
426
-
427
- cls_tok = tf.tile(self.CLS[None, None, :],
428
- [tf.shape(embedded_prompt)[0], 1, 1])
429
- embedded_prompt = tf.concat([cls_tok, embedded_prompt], axis=1)
430
- embedded_prompt += self.prompt_pos_enc
431
- embedded_prompt = self.transformer(embedded_prompt) # (batch, seq_len+1, emb_dim)
432
-
433
- # add t-embedding to every token (broadcasts along axis-1)
434
- embedded_prompt += embedded_t
435
-
436
- # return CLS (keep singleton axis if you need it)
437
- return embedded_prompt[:, 0, :] # (batch, 1, emb_dim)
438
-
439
-
440
- def get_config(self):
441
- config = super().get_config()
442
- config.update({
443
- "emb_dim": self.emb_dim,
444
- "T": self.T,
445
- "context_size": self.context_size,
446
- "vocab_size": self.vocab_size})
447
- return config
448
-
449
- # @title DiT Block
450
class Gain(tf.keras.layers.Layer):
    """Scale the input by a conditioning-dependent per-channel gain (AdaLN-Zero style)."""

    def __init__(self):
        super(Gain, self).__init__()

    def build(self, input_shape):
        # Dense predicting one gain per hidden channel from the conditioning vector.
        self.M = Dense(input_shape[2], use_bias=True,kernel_initializer='glorot_uniform')

    def call(self, x, cond):
        scale = self.M(cond)
        # Broadcast (batch, dim) gains over the token axis.
        x *= tf.expand_dims(scale, 1)
        return x
461
-
462
@register_keras_serializable()
class DiTBlock(tf.keras.layers.Layer):
    """DiT transformer block: AdaLN-conditioned attention and MLP with learned gains."""

    def __init__(self, hidden_dim, heads, context_size, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = hidden_dim
        self.heads = heads
        self.context_size = context_size  # kept for serialization; not read in call()
        self.gain1 = Gain()  # conditioning-dependent scale on the attention residual
        self.gain2 = Gain()  # conditioning-dependent scale on the MLP residual
        self.adaLN1 = AdaptiveLayerNorm()

        self.attn = MultiHeadAttention(num_heads=self.heads, key_dim=self.emb_dim//self.heads, output_shape=self.emb_dim)
        self.adaLN2 = AdaptiveLayerNorm()
        self.mlp_up = Dense(self.emb_dim*4, activation="gelu")
        self.mlp_down = Dense(self.emb_dim)

    def call(self, x, cond):
        # Attention sub-block: AdaLN -> self-attention -> gated residual.
        R = self.adaLN1(x, cond)
        R = self.gain1(self.attn(R, R, R), cond)
        x = x + R
        # MLP sub-block: AdaLN -> 4x GELU MLP -> gated residual.
        R = self.adaLN2(x, cond)
        R = self.mlp_up(R)
        R = self.gain2(self.mlp_down(R), cond)
        x = x + R
        return x

    def get_config(self):
        config = super().get_config()
        config.update({"hidden_dim": self.emb_dim,
                       "heads": self.heads,
                       "context_size": self.context_size})
        return config
494
-
495
# Load the pre-trained VAE encoder/decoder and the diffusion model from disk
# at import time, so the endpoint handler can reuse them across requests.
# NOTE(review): `encoder` is never used by the inference code below —
# presumably kept for parity with training; confirm before removing.
encoder = tf.keras.models.load_model("encoder.keras")
decoder = tf.keras.models.load_model("decoder.keras")
diffuser = tf.keras.models.load_model("diffusion-med-coco.keras")
498
-
499
def inference(prompts):
    """Run the full reverse-diffusion loop for a batch of text prompts.

    Args:
        prompts: list of prompt strings.

    Returns:
        Decoded images from the VAE decoder, one per prompt.
    """
    N = len(prompts)
    # Fixed: start from the configured `latent_shape` instead of a hard-coded
    # (32, 32, 4), matching EndpointHandler; dropped the unused `t_shape` local.
    x_t = tf.random.normal(shape=(N, *latent_shape))
    texts = tf.convert_to_tensor([process_text(p) for p in prompts])

    # Walk the DDPM chain backwards: predict the noise, then strip one step of it.
    for t in reversed(range(T)):
        t_batch = tf.convert_to_tensor([[t]] * N)
        eps_hat = diffuser([x_t, texts, t_batch])
        x_t = tf.convert_to_tensor(
            denoise_step(x_t.numpy(), eps_hat.numpy(), t, a, beta), dtype=tf.float32
        )

    return decoder(x_t)
513
-
514
class EndpointHandler:
    """Hugging Face Inference-Endpoints entry point: prompts in, images out."""

    def __init__(self, path="."):
        # Models are loaded once at module import time (see the load_model
        # calls above), so there is nothing to initialize per instance.
        pass  # models already loaded above

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one image per prompt in `inputs["inputs"]`.

        Args:
            inputs: payload dict; `inputs["inputs"]` is a list of prompt strings.

        Returns:
            {"outputs": nested lists of floats, one decoded image per prompt}.
        """
        prompts = inputs["inputs"]
        N = len(prompts)
        # Start from pure Gaussian noise in the VAE latent space.
        x_t = tf.random.normal(shape=(N, *latent_shape))
        texts = tf.convert_to_tensor([process_text(p) for p in prompts])

        # Reverse diffusion: predict the noise at step t, then remove one step of it.
        for t in reversed(range(T)):
            t_batch = tf.convert_to_tensor([[t]] * N)
            eps_hat = diffuser([x_t, texts, t_batch])
            x_t = tf.convert_to_tensor(
                denoise_step(x_t.numpy(), eps_hat.numpy(), t, a, beta), dtype=tf.float32
            )

        imgs = decoder(x_t)
        # NOTE(review): decoder output is presumably in [-1, 1]; callers may want
        # normalise_img before display — confirm.
        return {"outputs": imgs.numpy().tolist()}
 
1
"""Inference handler for a latent text-to-image diffusion model (VAE + DiT)."""

from typing import Dict, Any

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, Conv2D, UpSampling2D, Embedding, MultiHeadAttention
from tensorflow.keras.saving import register_keras_serializable
from transformers import PreTrainedTokenizerFast

# NOTE(review): `token2vec` appears unused in this module (prompt encoding
# below goes through tiktoken) — confirm before removing.
token2vec = PreTrainedTokenizerFast.from_pretrained("cl100k_tokenizer")
11
+
12
+ # @title Config
13
def small_config():
    """Hyperparameters for the small model variant (500 diffusion steps)."""
    steps = 500
    noise_schedule = np.linspace(1e-4, 0.02, steps)  # linear beta schedule
    keep_prob = 1 - noise_schedule
    return {
        "filters": [128, 256],
        "hidden_dim": 384,
        "heads": 6,
        "layers": 8,
        "patch_size": 4,
        "batch_size": 64,
        "T": steps,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": noise_schedule,
        "alpha": keep_prob,
        "a": np.cumprod(keep_prob),
    }
33
+
34
def med_config():
    """Hyperparameters for the medium model variant (1000 diffusion steps)."""
    steps = 1000
    noise_schedule = np.linspace(1e-4, 0.02, steps)  # linear beta schedule
    keep_prob = 1 - noise_schedule
    return {
        "filters": [128, 256],
        "hidden_dim": 768,
        "heads": 12,
        "layers": 12,
        "patch_size": 4,
        "batch_size": 64,
        "T": steps,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": noise_schedule,
        "alpha": keep_prob,
        "a": np.cumprod(keep_prob),
    }
54
+
55
def large_config():
    """Hyperparameters for the large model variant (1000 diffusion steps)."""
    steps = 1000
    noise_schedule = np.linspace(1e-4, 0.02, steps)  # linear beta schedule
    keep_prob = 1 - noise_schedule
    return {
        "filters": [128, 256],
        "hidden_dim": 1024,
        "heads": 16,
        "layers": 24,
        "patch_size": 4,
        "batch_size": 64,
        "T": steps,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": noise_schedule,
        "alpha": keep_prob,
        "a": np.cumprod(keep_prob),
    }
75
+
76
config = med_config()  # active hyperparameter set for this deployment (medium)

# Unpack the config into the module-level globals the rest of the file uses.
filters = config['filters']
hidden_dim = config['hidden_dim']
heads = config['heads']
layers = config['layers']
patch_size = config['patch_size']
batch_size = config['batch_size']
T = config['T']  # number of diffusion timesteps
context_size = config['context_size']  # prompt length in tokens
image_size = config['image_size']
latent_shape = config['latent_shape']
beta = config['beta']  # per-step noise variances
alpha = config['alpha']  # 1 - beta
a = config['a']  # cumulative product of alpha ("alpha-bar")
91
+
92
+ # @title ResBlock, UpBlock, DownBlock
93
@register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Residual block: 1x1 channel projection plus a two-conv residual branch."""

    def __init__(self, filters, p, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.filters = filters  # output channel count
        self.p = p  # kernel size of the residual convs
        # 1x1 conv so the skip connection matches the residual branch's channels
        self.reshape = Conv2D(filters, kernel_size=1, strides=1, padding="same")
        #self.norm = BatchNormalization(center=False, scale=False)
        self.conv1 = Conv2D(filters, kernel_size=p, strides=1, padding="same", activation="swish")
        self.conv2 = Conv2D(filters, kernel_size=p, strides=1, padding="same")

    def call(self, x):
        # Project to the target channel count, then add the conv residual.
        x = self.reshape(x)
        resid = x
        #resid = self.norm(resid)
        resid = self.conv1(resid)
        resid = self.conv2(resid)
        x = x + resid
        return x

    def get_config(self):
        # Serialize constructor args so the layer round-trips through .keras files.
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "p": self.p})
        return config
119
+
120
@register_keras_serializable()
class DownBlock(tf.keras.layers.Layer):
    """Stack of ResBlocks followed by 2x2 max-pooling (halves H and W)."""

    def __init__(self, filters, **kwargs):
        super(DownBlock, self).__init__(**kwargs)
        self.filters = filters  # one ResBlock per entry, kernel size 3
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    def call(self, x):
        for resBlock in self.resBlocks:
            x = resBlock(x)
        x = self.pool(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters})
        return config
139
+
140
@register_keras_serializable()
class UpBlock(tf.keras.layers.Layer):
    """Bilinear 2x upsampling followed by a stack of ResBlocks."""

    def __init__(self, filters, **kwargs):
        super(UpBlock, self).__init__(**kwargs)
        self.filters = filters  # one ResBlock per entry, kernel size 3
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.upSample = UpSampling2D(size=2, interpolation="bilinear")

    def call(self, x):
        # Upsample first, then refine with residual convs.
        x = self.upSample(x)
        for resBlock in self.resBlocks:
            x = resBlock(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters})
        return config
159
+
160
+ # @title Encoder, Decoder
161
@register_keras_serializable()
class Encoder(tf.keras.Model):
    """VAE encoder: downsampling ResNet trunk + reparameterized Gaussian latent."""

    def __init__(self, filters, latent_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.filters = filters
        self.latent_dim = latent_dim  # channels of the latent map
        self.downBlocks = [DownBlock([f,f]) for f in filters]  # each halves H and W
        # 1x1 conv emitting mu and logvar stacked along the channel axis
        self.latent_proj = Conv2D(latent_dim * 2, kernel_size=1, strides=1, padding="same", activation="linear")

    @tf.function
    def sample(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        eps = tf.random.normal(shape=tf.shape(mu))
        return eps * tf.exp(logvar * .5) + mu

    def call(self, x, training=1):
        for downBlock in self.downBlocks:
            x = downBlock(x)
        x = self.latent_proj(x)
        # Split channels into the Gaussian parameters, then sample z.
        mu, logvar = tf.split(x, 2, axis=-1)
        z = self.sample(mu, logvar)
        return z, mu, logvar

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "latent_dim": self.latent_dim})
        return config

    def compute_output_shape(self, input_shape):
        # NOTE(review): call() returns 4-D latent maps, but this declares 2-D
        # (batch, latent_dim) shapes — looks inconsistent; confirm intent.
        return (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim)
192
+
193
@register_keras_serializable()
class Decoder(tf.keras.Model):
    """VAE decoder: upsampling ResNet trunk mapping latent maps back to RGB."""

    def __init__(self, filters, img_size, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        # Stored reversed; get_config re-reverses so the ctor arg round-trips.
        self.filters = filters[::-1]
        self.img_size = img_size
        self.undo_latent_proj = Conv2D(filters[0], kernel_size=1, strides=1, padding="same")
        # NOTE(review): upBlocks iterate `filters` in the given order, not the
        # reversed `self.filters` — confirm the intended channel progression.
        self.upBlocks = [UpBlock([f,f]) for f in filters]
        self.conv_proj = Conv2D(3, kernel_size=3, padding="same", activation="linear")

    def call(self, z, training=1):
        z = self.undo_latent_proj(z)
        for upBlock in self.upBlocks:
            z = upBlock(z)
        x = self.conv_proj(z)  # project to 3 RGB channels, linear activation
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters[::-1],
            "img_size": self.img_size})
        return config

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.img_size, self.img_size, 3)
219
+
220
+ # @title Helper Functions
221
# Cached tiktoken encoder; the original rebuilt it on every call.
_prompt_tokenizer = None

def _get_prompt_tokenizer():
    """Build the gpt-3.5-turbo tiktoken encoder once and reuse it.

    Imported lazily so the module still loads when tiktoken is absent.
    """
    global _prompt_tokenizer
    if _prompt_tokenizer is None:
        import tiktoken
        _prompt_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return _prompt_tokenizer

def process_text(text):
    """Encode `text` into exactly `context_size` token ids.

    Short prompts are right-padded with 0; long prompts are truncated.
    """
    tokens = _get_prompt_tokenizer().encode(text)
    # Pad up to the fixed context window, then clip to it.
    tokens.extend([0] * (context_size - len(tokens)))
    return tokens[:context_size]
228
+
229
def normalise_img(img_tensor):  # Maps [-1,1] to [0,1]
    """Rescale an image from [-1, 1] to [0, 1] and return the result.

    Fixed: the original used in-place ops (`*=`, `+=`) on the argument, silently
    mutating the caller's array (prep_img copies; this did not). Now the input
    is left untouched.
    """
    return img_tensor * 0.5 + 0.5
234
+
235
def prep_img(img_tensor):  # Maps [0,255] to [-1,1]
    """Rescale a [0, 255] image to [-1, 1], leaving the input untouched."""
    scaled = img_tensor.copy() / 127.5
    return scaled - 1
240
+
241
def noisify_img(img_tensor, t, a):  # Returns x_t and the noise used
    """Forward-diffuse an image to timestep `t`.

    Returns (x_t, epsilon) where x_t = sqrt(a[t]) * img + sqrt(1 - a[t]) * eps
    and epsilon is the standard-normal noise that was drawn.
    """
    noise = np.random.normal(0, 1, img_tensor.shape).astype(np.float32)
    signal_scale = np.sqrt(a[t])
    noise_scale = np.sqrt(1 - a[t])
    return signal_scale * img_tensor + noise_scale * noise, noise
247
+
248
def denoise_step(x_t, eps_hat, t, a, beta):
    """
    Reverse one DDPM step: x_t → x_{t-1}

    Args:
        x_t: current noisy latent.
        eps_hat: model's noise prediction for step t.
        t: integer timestep index into the schedules.
        a: cumulative alpha products ("alpha-bar") per step.
        beta: per-step noise variances.

    Returns:
        x_{t-1}, the latent one step less noisy.
    """
    a_bar_t = tf.convert_to_tensor(a[t], dtype=tf.float32)
    # Step 0 has no predecessor, so treat its previous alpha-bar as 1.0.
    a_bar_prev = tf.convert_to_tensor(a[t - 1] if t > 0 else 1.0, dtype=tf.float32)
    a_t = a_bar_t / a_bar_prev  # per-step alpha recovered from the cumulative products
    beta_t = tf.convert_to_tensor(beta[t], dtype=tf.float32)

    # Avoid NaNs with clamping
    sqrt_recip_a_t = tf.math.rsqrt(tf.maximum(a_t, 1e-5))
    sqrt_one_minus_ab = tf.sqrt(tf.maximum(1. - a_bar_t, 1e-5))

    # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)
    eps_term = (beta_t / sqrt_one_minus_ab) * eps_hat
    mean = sqrt_recip_a_t * (x_t - eps_term)

    if t > 1:
        # Stochastic step with sigma^2 = beta_t.
        noise = tf.random.normal(shape=x_t.shape)
        sigma = tf.sqrt(tf.maximum(beta_t, 1e-5))
        x_prev = mean + sigma * noise
    else:
        # NOTE(review): both t == 1 and t == 0 are deterministic here; textbook
        # DDPM only drops the noise at the final step — confirm this is intended.
        x_prev = mean

    return x_prev
272
+
273
+ # @title Transformer Block
274
@register_keras_serializable()
class TransformerBlock(tf.keras.Layer):
    """Pre-norm transformer encoder block: self-attention + 4x GELU MLP."""

    def __init__(self, context_size, head_no, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.context_size = context_size  # kept for serialization; not read in call()
        self.head_no = head_no
        self.latent_dim = latent_dim
        self.attn = MultiHeadAttention(num_heads=head_no, key_dim=latent_dim//head_no, output_shape=latent_dim)
        self.mlp_up = Dense(latent_dim*4, activation="gelu")  # standard 4x MLP expansion
        self.mlp_down = Dense(latent_dim)
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()

    def call(self, x):
        # Pre-norm residual attention, then pre-norm residual MLP.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)
        normed = self.norm2(x)
        dx = self.mlp_up(normed)
        x = x + self.mlp_down(dx)
        return x

    def build(self, input_shape):
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        # Shape preserving: (batch, seq, latent_dim) in and out.
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update({
            "context_size": self.context_size,
            "head_no": self.head_no,
            "latent_dim": self.latent_dim})
        return config
308
+
309
+ # @title AdaLN-Zero
310
@register_keras_serializable()
class AdaptiveLayerNorm(tf.keras.Layer):
    """LayerNorm whose scale and shift are predicted from a conditioning vector (AdaLN)."""

    def __init__(self, eps=1e-6, **kwargs):
        # Fixed: super().__init__ must run before sub-layers are assigned so
        # Keras attribute tracking registers them; also keep `eps` for get_config.
        super(AdaptiveLayerNorm, self).__init__(**kwargs)
        self.eps = eps
        self.layernorm = LayerNormalization(epsilon=eps, center=False, scale=False)

    def build(self, input_shape):
        # input_shape: (batch, num_patches, hidden_dim)
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")  # predicts per-channel scale
        self.b = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")  # predicts per-channel shift
        # Fixed: mark the layer as built so Keras does not re-enter build().
        super().build(input_shape)

    def call(self, x, cond):
        gamma = self.M(cond)
        beta = self.b(cond)
        x = self.layernorm(x)
        # Broadcast the (batch, dim) conditioning across the patch axis.
        x = x * (1 + tf.expand_dims(gamma, 1)) + tf.expand_dims(beta, 1)
        return x

    def get_config(self):
        # Fixed: serialize `eps` so a custom epsilon survives save/load.
        config = super().get_config()
        config.update({"eps": self.eps})
        return config
331
+
332
+ # @title Image Embedder, Unembedder
333
@register_keras_serializable()
class ImageEmbedder(tf.keras.Layer):
    """Patchify a latent image into tokens and add learned positional embeddings."""

    def __init__(self, latent_size, patch_size, emb_dim,**kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.patch_size = patch_size
        self.latent_size = latent_size  # spatial side length of the latent map
        # One learned position vector per patch.
        self.pos_emb = Embedding(input_dim=(latent_size // patch_size)**2 , output_dim=emb_dim, embeddings_initializer="glorot_uniform")
        # NOTE(review): this Dense mixes channels per pixel *before* the patchify
        # conv — unusual ordering; confirm it is intentional.
        self.reshaper = Dense(emb_dim, kernel_initializer="glorot_uniform")
        # Strided conv = non-overlapping patch embedding.
        self.conv_expansion = Conv2D(emb_dim, kernel_size=patch_size, strides=patch_size, padding="same")

    def call(self, x):
        x = self.reshaper(x)
        x = self.conv_expansion(x)
        # Flatten the spatial grid into a token sequence: (B, H*W, emb_dim).
        x = tf.reshape(x, shape=[tf.shape(x)[0], tf.shape(x)[1]*tf.shape(x)[2], tf.shape(x)[3]])
        positions = tf.range(start=0, limit=(self.latent_size // self.patch_size)**2, delta=1)
        embeddings = self.pos_emb(positions)
        x = embeddings + x  # broadcast positions over the batch axis
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_size" : self.latent_size,
            "patch_size": self.patch_size,
            "emb_dim": self.emb_dim})
        return config
360
+
361
@register_keras_serializable()
class ImageUnembedder(tf.keras.Layer):
    """Project DiT tokens back to a latent image via AdaLN + depth-to-space."""

    def __init__(self, latent_size, patch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim  # channel count of the output latent
        self.patch_size = patch_size
        self.latent_size = latent_size  # spatial side length of the output latent
        self.AdaLN = AdaptiveLayerNorm()
        # Each token predicts a patch_size x patch_size x latent_dim pixel block.
        self.reshape_to_latent = Dense(patch_size*patch_size*latent_dim, kernel_initializer="glorot_uniform")

    def call(self, x, cond):
        x = self.AdaLN(x, cond)
        x = self.reshape_to_latent(x)
        # Tokens -> (B, grid, grid, patch^2 * C), then unshuffle patches to pixels.
        x = tf.reshape(x, shape=
            [tf.shape(x)[0],
             self.latent_size // self.patch_size,
             self.latent_size // self.patch_size,
             self.latent_dim*(self.patch_size**2)])
        x = tf.nn.depth_to_space(x, block_size=self.patch_size)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_size" : self.latent_size,
            "patch_size": self.patch_size,
            "latent_dim": self.latent_dim})
        return config
389
+
390
+ # @title LEGACY Prompt and Timestep Embedder
391
@register_keras_serializable()
class ConditioningEmbedder(tf.keras.layers.Layer):
    """LEGACY: embed (timestep, prompt tokens) into one conditioning vector.

    A fixed sinusoidal timestep table plus a single transformer block over the
    prompt; the CLS token's output, with the timestep embedding added, is the
    conditioning vector.
    """

    def __init__(self, emb_dim, T, context_size, vocab_size=100266, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.T = T  # number of timesteps covered by the embedding table
        self.context_size = context_size  # prompt length (excluding CLS)
        self.vocab_size = vocab_size  # presumably cl100k-sized vocab — confirm
        # Precompute the fixed sinusoidal timestep table, shape (T, emb_dim).
        positions = tf.range(T, dtype=tf.float32)[:, tf.newaxis]
        frequencies = tf.constant(10000 ** (-tf.range(0, emb_dim, 2, dtype=tf.float32) / emb_dim))
        angle_rates = positions * frequencies # (T, emb_dim/2)
        sin_part = tf.sin(angle_rates)
        cos_part = tf.cos(angle_rates)
        emb = tf.stack([sin_part, cos_part], axis=-1) # (T, emb_dim/2, 2)
        emb = tf.reshape(emb, [T, emb_dim]) # (T, emb_dim)
        self.t_embeddings = tf.constant(emb, dtype=tf.float32)

        # Learned pieces: token table, CLS token, positional encoding, transformer.
        self.prompt_emb = self.add_weight(shape=(vocab_size, emb_dim), initializer='glorot_uniform', name='prompt_emb', trainable=True)
        self.CLS = self.add_weight(shape=(emb_dim,), initializer='glorot_uniform', name='CLS', trainable=True)
        self.prompt_pos_enc = self.add_weight(shape=(1, context_size+1, emb_dim), initializer='glorot_uniform', name='prompt_pos_enc', trainable=True)
        self.transformer = TransformerBlock(context_size+1, head_no=6, latent_dim=emb_dim)

    def call(self, x):
        # x is a (t, prompt_tokens) pair: t is (batch, 1); tokens are int ids.
        t, prompt_tokens = x

        # ── timestep embedding ───────────────────────────
        t = tf.cast(tf.squeeze(t, axis=-1), tf.int32) # (batch,)
        embedded_t = tf.gather(self.t_embeddings, t) # (batch, emb_dim)
        embedded_t = embedded_t[:, tf.newaxis, :] # (batch, 1, emb_dim)

        # ── prompt embedding path ─────────────────────────
        embedded_prompt = tf.nn.embedding_lookup(
            self.prompt_emb, prompt_tokens) # (batch, seq_len, emb_dim)

        # Prepend the learned CLS token to every sequence in the batch.
        cls_tok = tf.tile(self.CLS[None, None, :],
                          [tf.shape(embedded_prompt)[0], 1, 1])
        embedded_prompt = tf.concat([cls_tok, embedded_prompt], axis=1)
        embedded_prompt += self.prompt_pos_enc
        embedded_prompt = self.transformer(embedded_prompt) # (batch, seq_len+1, emb_dim)

        # add t-embedding to every token (broadcasts along axis-1)
        embedded_prompt += embedded_t

        # Return the CLS token's vector — shape (batch, emb_dim).
        return embedded_prompt[:, 0, :]


    def get_config(self):
        config = super().get_config()
        config.update({
            "emb_dim": self.emb_dim,
            "T": self.T,
            "context_size": self.context_size,
            "vocab_size": self.vocab_size})
        return config
446
+
447
# @title DiT Block
class Gain(tf.keras.layers.Layer):
    """Per-channel gate: scales ``x`` by a learned projection of ``cond``.

    Used by DiTBlock to modulate its residual branches. The projection
    is sized lazily in ``build`` from the channel dimension of ``x``.

    Fix: use zero-argument ``super().__init__(**kwargs)`` so standard
    layer kwargs (``name``, ``dtype``, …) are forwarded like every other
    layer in this file; the original swallowed them.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is (batch, tokens, channels); project cond to one
        # scale per channel.
        self.M = Dense(input_shape[2], use_bias=True,
                       kernel_initializer='glorot_uniform')

    def call(self, x, cond):
        scale = self.M(cond)                 # (batch, channels)
        # Insert a token axis so the scale broadcasts over the sequence.
        x *= tf.expand_dims(scale, 1)
        return x
@register_keras_serializable()
class DiTBlock(tf.keras.layers.Layer):
    """One conditioned transformer block: an adaptive-norm self-attention
    branch followed by an adaptive-norm MLP branch, each gated by the
    conditioning vector before being added back to the residual stream."""

    def __init__(self, hidden_dim, heads, context_size, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = hidden_dim
        self.heads = heads
        self.context_size = context_size
        self.gain1 = Gain()
        self.gain2 = Gain()
        self.adaLN1 = AdaptiveLayerNorm()

        self.attn = MultiHeadAttention(
            num_heads=self.heads,
            key_dim=self.emb_dim // self.heads,
            output_shape=self.emb_dim)
        self.adaLN2 = AdaptiveLayerNorm()
        self.mlp_up = Dense(self.emb_dim * 4, activation="gelu")
        self.mlp_down = Dense(self.emb_dim)

    def call(self, x, cond):
        # Attention branch: norm -> self-attention -> conditional gate.
        normed = self.adaLN1(x, cond)
        attn_out = self.attn(normed, normed, normed)
        x = x + self.gain1(attn_out, cond)

        # MLP branch: norm -> 4x expand / contract -> conditional gate.
        normed = self.adaLN2(x, cond)
        mlp_out = self.mlp_down(self.mlp_up(normed))
        x = x + self.gain2(mlp_out, cond)
        return x

    def get_config(self):
        """Serialize constructor arguments for Keras deserialization."""
        config = super().get_config()
        config.update({"hidden_dim": self.emb_dim,
                       "heads": self.heads,
                       "context_size": self.context_size})
        return config
# Load the pretrained VAE encoder/decoder and the diffusion model once at
# import time; paths are relative to the endpoint's working directory.
encoder = tf.keras.models.load_model("encoder.keras")
decoder = tf.keras.models.load_model("decoder.keras")
diffuser = tf.keras.models.load_model("diffusion-med-coco.keras")
def inference(prompts):
    """Run the full reverse-diffusion loop for a batch of text prompts.

    Args:
        prompts: list of prompt strings; each is tokenized by
            ``process_text`` before conditioning the diffuser.

    Returns:
        Decoded images from the VAE decoder, one per prompt.

    Fixes: drop the unused local ``t_shape`` and derive the noise shape
    from the module-level ``latent_shape`` (as ``EndpointHandler.__call__``
    already does) instead of hard-coding ``(32, 32, 4)``.
    """
    N = len(prompts)
    # Start from pure Gaussian noise in latent space.
    x_t = tf.random.normal(shape=(N, *latent_shape))
    texts = tf.convert_to_tensor([process_text(p) for p in prompts])

    # Iteratively denoise from t = T-1 down to t = 0.
    for t in reversed(range(T)):
        t_batch = tf.convert_to_tensor([[t]] * N)
        eps_hat = diffuser([x_t, texts, t_batch])
        # denoise_step operates on NumPy arrays using schedule constants
        # a (cumulative alphas) and beta.
        x_t = tf.convert_to_tensor(
            denoise_step(x_t.numpy(), eps_hat.numpy(), t, a, beta),
            dtype=tf.float32)

    # Decode the final latent back to image space.
    imgs = decoder(x_t.numpy())
    return imgs
class EndpointHandler:
    """Hugging Face Inference Endpoints entry point.

    The heavy models (encoder/decoder/diffuser) are loaded once at module
    import, so construction is a no-op.
    """

    def __init__(self, path="."):
        # Models already live at module scope; nothing to initialize here.
        pass

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one image per prompt via the reverse-diffusion loop."""
        prompts = inputs["inputs"]
        batch = len(prompts)
        latents = tf.random.normal(shape=(batch, *latent_shape))
        tokens = tf.convert_to_tensor([process_text(p) for p in prompts])

        # Walk the schedule backwards: t = T-1, ..., 0.
        for step in range(T - 1, -1, -1):
            step_batch = tf.convert_to_tensor([[step]] * batch)
            predicted_noise = diffuser([latents, tokens, step_batch])
            latents = tf.convert_to_tensor(
                denoise_step(latents.numpy(), predicted_noise.numpy(),
                             step, a, beta),
                dtype=tf.float32,
            )

        images = decoder(latents)
        return {"outputs": images.numpy().tolist()}