detectivejoewest committed on
Commit
fed1d4b
·
verified ·
1 Parent(s): 1e4231c

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +529 -0
handler.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+ from typing import Dict, Any
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.keras.layers import Dense, LayerNormalization, Conv2D, UpSampling2D, Embedding, MultiHeadAttention
6
+ from tensorflow.keras.saving import register_keras_serializable
7
+ import tensorflow as tf
8
+
9
# Module-level tiktoken encoding for the "gpt-3.5-turbo" model, created once at import.
token2vec = tiktoken.encoding_for_model("gpt-3.5-turbo")
10
+
11
+ # @title Config
12
def _make_config(T, hidden_dim, heads, layers):
    """Build one configuration dict; only the four arguments vary between sizes.

    Args:
        T: Number of diffusion timesteps; also the length of the schedules.
        hidden_dim: Transformer hidden width.
        heads: Number of attention heads.
        layers: Number of transformer layers.

    Returns:
        A dict with the architecture settings plus the noise schedule:
        ``beta`` (linear from 1e-4 to 0.02), ``alpha`` (``1 - beta``) and
        ``a`` (cumulative product of ``alpha``), each a NumPy array of length T.
    """
    beta = np.linspace(1e-4, 0.02, T)
    alpha = 1 - beta
    a = np.cumprod(alpha)

    return {
        "filters": [128, 256],
        "hidden_dim": hidden_dim,
        "heads": heads,
        "layers": layers,
        "patch_size": 4,
        "batch_size": 64,
        "T": T,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": beta,
        "alpha": alpha,
        "a": a,
    }


def small_config():
    """Return the small configuration (T=500, hidden_dim=384, 6 heads, 8 layers)."""
    return _make_config(T=500, hidden_dim=384, heads=6, layers=8)


def med_config():
    """Return the medium configuration (T=1000, hidden_dim=768, 12 heads, 12 layers)."""
    return _make_config(T=1000, hidden_dim=768, heads=12, layers=12)


def large_config():
    """Return the large configuration (T=1000, hidden_dim=1024, 16 heads, 24 layers)."""
    return _make_config(T=1000, hidden_dim=1024, heads=16, layers=24)
74
+
75
# Active configuration; the settings are re-exported as module-level names
# because the helper functions and the inference code read them as globals.
config = med_config()

(
    filters,
    hidden_dim,
    heads,
    layers,
    patch_size,
    batch_size,
    T,
    context_size,
    image_size,
    latent_shape,
    beta,
    alpha,
    a,
) = (
    config[key]
    for key in (
        "filters",
        "hidden_dim",
        "heads",
        "layers",
        "patch_size",
        "batch_size",
        "T",
        "context_size",
        "image_size",
        "latent_shape",
        "beta",
        "alpha",
        "a",
    )
)
90
+
91
+ # @title ResBlock, UpBlock, DownBlock
92
@register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Residual convolution block.

    The input is first passed through a 1x1 convolution that sets the channel
    count to ``filters``; two ``p`` x ``p`` convolutions (the first with swish
    activation) are then applied and their result is added back to the
    projected input.

    Args:
        filters: Number of output channels of every convolution in the block.
        p: Kernel size of the two residual convolutions.
    """

    def __init__(self, filters, p, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.p = p
        # 1x1 projection so the skip connection and the residual share a channel count.
        self.reshape = Conv2D(filters, kernel_size=1, strides=1, padding="same")
        self.conv1 = Conv2D(filters, kernel_size=p, strides=1, padding="same", activation="swish")
        self.conv2 = Conv2D(filters, kernel_size=p, strides=1, padding="same")

    def call(self, x):
        """Return ``proj(x) + conv2(conv1(proj(x)))``."""
        projected = self.reshape(x)
        residual = self.conv2(self.conv1(projected))
        return projected + residual

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        base = super().get_config()
        base.update({"filters": self.filters, "p": self.p})
        return base
118
+
119
@register_keras_serializable()
class DownBlock(tf.keras.layers.Layer):
    """Stack of ``ResBlock`` layers followed by 2x2 max pooling.

    Args:
        filters: Iterable of channel counts; one ``ResBlock`` with kernel
            size 3 is created per entry, applied in order.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.resBlocks = [ResBlock(width, p=3) for width in filters]
        self.pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    def call(self, x):
        """Apply every residual block in order, then pool."""
        features = x
        for block in self.resBlocks:
            features = block(features)
        return self.pool(features)

    def get_config(self):
        """Return the base layer config extended with ``filters``."""
        base = super().get_config()
        base["filters"] = self.filters
        return base
138
+
139
@register_keras_serializable()
class UpBlock(tf.keras.layers.Layer):
    """Bilinear 2x upsampling followed by a stack of ``ResBlock`` layers.

    Args:
        filters: Iterable of channel counts; one ``ResBlock`` with kernel
            size 3 is created per entry, applied in order after upsampling.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.resBlocks = [ResBlock(width, p=3) for width in filters]
        self.upSample = UpSampling2D(size=2, interpolation="bilinear")

    def call(self, x):
        """Upsample first, then apply every residual block in order."""
        features = self.upSample(x)
        for block in self.resBlocks:
            features = block(features)
        return features

    def get_config(self):
        """Return the base layer config extended with ``filters``."""
        base = super().get_config()
        base["filters"] = self.filters
        return base
158
+
159
+ # @title Encoder, Decoder
160
@register_keras_serializable()
class Encoder(tf.keras.Model):
    """Convolutional encoder that outputs a sampled latent with its mean and log-variance.

    One ``DownBlock([f, f])`` is created per entry of ``filters``; each ends
    with 2x2 max pooling, so height and width shrink by ``2 ** len(filters)``.
    A final 1x1 convolution produces ``2 * latent_dim`` channels that are
    split into mean and log-variance.

    Args:
        filters: Channel counts, one down-sampling stage per entry.
        latent_dim: Number of channels of the latent.
    """

    def __init__(self, filters, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.latent_dim = latent_dim
        self.downBlocks = [DownBlock([f, f]) for f in filters]
        self.latent_proj = Conv2D(latent_dim * 2, kernel_size=1, strides=1, padding="same", activation="linear")

    @tf.function
    def sample(self, mu, logvar):
        """Draw ``mu + exp(logvar / 2) * eps`` with ``eps`` standard normal."""
        eps = tf.random.normal(shape=tf.shape(mu))
        return eps * tf.exp(logvar * .5) + mu

    def call(self, x, training=1):
        """Encode ``x`` and return ``(z, mu, logvar)``.

        ``training`` is accepted for interface compatibility but is not read:
        a latent is sampled on every call.
        """
        for downBlock in self.downBlocks:
            x = downBlock(x)
        x = self.latent_proj(x)
        mu, logvar = tf.split(x, 2, axis=-1)
        z = self.sample(mu, logvar)
        return z, mu, logvar

    def get_config(self):
        """Return the base model config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "latent_dim": self.latent_dim})
        return config

    def compute_output_shape(self, input_shape):
        """Return the shapes of ``(z, mu, logvar)`` for a 4-D ``input_shape``.

        The previous implementation reported ``(batch, latent_dim)``, which does
        not match what ``call`` returns: the outputs are feature maps whose
        height and width are the input's divided by ``2 ** len(filters)``
        (one 2x2 max pool per stage). Unknown (``None``) dimensions stay unknown.
        """
        factor = 2 ** len(self.filters)

        def pooled(dim):
            return None if dim is None else dim // factor

        latent = (input_shape[0], pooled(input_shape[1]), pooled(input_shape[2]), self.latent_dim)
        return latent, latent, latent
191
+
192
@register_keras_serializable()
class Decoder(tf.keras.Model):
    """Convolutional decoder that maps a latent feature map to a 3-channel image.

    Args:
        filters: Channel counts; one ``UpBlock([f, f])`` (2x upsampling) is
            created per entry, in the order given.
        img_size: Side length reported by ``compute_output_shape``.
    """

    def __init__(self, filters, img_size, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the reversed list is only stored (and un-reversed again
        # in get_config); the layers below are built from the argument in its
        # original order. Left as-is because changing it alters the architecture.
        self.filters = filters[::-1]
        self.img_size = img_size
        self.undo_latent_proj = Conv2D(filters[0], kernel_size=1, strides=1, padding="same")
        self.upBlocks = [UpBlock([width, width]) for width in filters]
        self.conv_proj = Conv2D(3, kernel_size=3, padding="same", activation="linear")

    def call(self, z, training=1):
        """Decode latent ``z``; ``training`` is accepted but not read."""
        features = self.undo_latent_proj(z)
        for stage in self.upBlocks:
            features = stage(features)
        return self.conv_proj(features)

    def get_config(self):
        """Return the base model config with ``filters`` in constructor order."""
        base = super().get_config()
        base.update({"filters": self.filters[::-1], "img_size": self.img_size})
        return base

    def compute_output_shape(self, input_shape):
        """Report ``(batch, img_size, img_size, 3)``."""
        batch = input_shape[0]
        return (batch, self.img_size, self.img_size, 3)
218
+
219
+ # @title Helper Functions
220
def process_text(text):
    """Tokenize ``text`` into exactly ``context_size`` token ids.

    The text is encoded with the module-level ``token2vec`` encoder (the same
    "gpt-3.5-turbo" encoding that was previously re-created on every call),
    truncated to ``context_size`` tokens and right-padded with token id 0.

    Args:
        text: The prompt string.

    Returns:
        A list of ``context_size`` integers.
    """
    tokens = token2vec.encode(text)[:context_size]
    # Pad in one step rather than appending zeros one at a time.
    return tokens + [0] * (context_size - len(tokens))
227
+
228
def normalise_img(img_tensor):
    """Map values from [-1, 1] to [0, 1] via ``x * 0.5 + 0.5``.

    A new object is returned. The previous in-place ``*=`` / ``+=`` modified
    the caller's NumPy array and raised on integer arrays.

    Args:
        img_tensor: Array or tensor supporting ``*`` and ``+`` with floats.

    Returns:
        The rescaled values, same shape as the input.
    """
    return img_tensor * 0.5 + 0.5
233
+
234
def prep_img(img_tensor):
    """Map values from [0, 255] to [-1, 1] via ``x / 127.5 - 1``.

    The division already allocates a new result, so the input is never
    modified and no explicit copy is needed. Dropping the old ``.copy()``
    call also lets inputs without a ``copy`` method (e.g. plain floats) work.

    Args:
        img_tensor: Array, tensor or scalar supporting ``/`` and ``-``.

    Returns:
        The rescaled values, same shape as the input.
    """
    return img_tensor / 127.5 - 1
239
+
240
def noisify_img(img_tensor, t, a):
    """Apply the forward diffusion step and return ``(x_t, epsilon)``.

    Computes ``sqrt(a[t]) * img + sqrt(1 - a[t]) * epsilon`` with ``epsilon``
    drawn from a standard normal (float32, same shape as ``img_tensor``).

    Args:
        img_tensor: Array with a ``.shape`` attribute.
        t: Integer timestep, or an integer array of per-sample timesteps
            indexing the leading dimension(s) of ``img_tensor``.
        a: Cumulative alpha schedule, indexable by ``t``.

    Returns:
        Tuple ``(x_t, epsilon)``.
    """
    epsilon = np.random.normal(0, 1, img_tensor.shape).astype(np.float32)  # Standard normal
    a_bar_t = a[t]
    if np.ndim(a_bar_t) > 0:
        # Per-sample timesteps: append singleton axes so the coefficients
        # broadcast over the remaining image dimensions.
        extra_axes = np.ndim(img_tensor) - np.ndim(a_bar_t)
        a_bar_t = np.reshape(a_bar_t, np.shape(a_bar_t) + (1,) * extra_axes)
    sqrt_alpha_bar = np.sqrt(a_bar_t)
    sqrt_one_minus_alpha_bar = np.sqrt(1 - a_bar_t)
    x_t = sqrt_alpha_bar * img_tensor + sqrt_one_minus_alpha_bar * epsilon
    return x_t, epsilon
246
+
247
def denoise_step(x_t, eps_hat, t, a, beta):
    """
    Reverse one DDPM step: x_t -> x_{t-1}.

    Args:
        x_t: Current noisy sample (array or tensor).
        eps_hat: Predicted noise for ``x_t``, same shape.
        t: Python integer timestep, 0-indexed into ``a`` and ``beta``.
        a: Cumulative alpha schedule.
        beta: Beta schedule.

    Returns:
        A float32 tensor holding the sample for the previous timestep.
    """
    a_bar_t = tf.convert_to_tensor(a[t], dtype=tf.float32)
    # For t == 0 there is no previous cumulative product; 1.0 makes a_t == a_bar_t.
    a_bar_prev = tf.convert_to_tensor(a[t - 1] if t > 0 else 1.0, dtype=tf.float32)
    a_t = a_bar_t / a_bar_prev
    beta_t = tf.convert_to_tensor(beta[t], dtype=tf.float32)

    # Avoid NaNs with clamping
    sqrt_recip_a_t = tf.math.rsqrt(tf.maximum(a_t, 1e-5))
    sqrt_one_minus_ab = tf.sqrt(tf.maximum(1. - a_bar_t, 1e-5))

    eps_term = (beta_t / sqrt_one_minus_ab) * eps_hat
    mean = sqrt_recip_a_t * (x_t - eps_term)

    # Timesteps are 0-indexed here, so the last reverse step is t == 0; noise is
    # added on every other step. (The earlier `t > 1` test is the 1-indexed
    # condition and skipped the noise one step too early.)
    if t > 0:
        noise = tf.random.normal(shape=tf.shape(x_t))
        sigma = tf.sqrt(tf.maximum(beta_t, 1e-5))
        x_prev = mean + sigma * noise
    else:
        x_prev = mean

    return x_prev
271
+
272
+ # @title Transformer Block
273
@register_keras_serializable()
class TransformerBlock(tf.keras.Layer):
    """Pre-norm transformer block: self-attention and a 4x GELU MLP, each with a skip.

    Args:
        context_size: Stored and returned by ``get_config``; not otherwise read here.
        head_no: Number of attention heads.
        latent_dim: Model width; the per-head key size is ``latent_dim // head_no``.
    """

    def __init__(self, context_size, head_no, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.context_size = context_size
        self.head_no = head_no
        self.latent_dim = latent_dim
        self.attn = MultiHeadAttention(
            num_heads=head_no,
            key_dim=latent_dim // head_no,
            output_shape=latent_dim,
        )
        self.mlp_up = Dense(latent_dim * 4, activation="gelu")
        self.mlp_down = Dense(latent_dim)
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()

    def call(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_in = self.norm1(x)
        x = x + self.attn(attn_in, attn_in, attn_in)
        mlp_hidden = self.mlp_up(self.norm2(x))
        return x + self.mlp_down(mlp_hidden)

    def build(self, input_shape):
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        """The block preserves the shape of its input."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        base = super().get_config()
        base.update(
            {
                "context_size": self.context_size,
                "head_no": self.head_no,
                "latent_dim": self.latent_dim,
            }
        )
        return base
307
+
308
+ # @title AdaLN-Zero
309
@register_keras_serializable()
class AdaptiveLayerNorm(tf.keras.Layer):
    """Layer normalisation whose scale and shift are predicted from a conditioning vector.

    ``x`` is normalised without learned centre/scale, then transformed as
    ``x * (1 + gamma) + beta`` where ``gamma`` and ``beta`` are linear
    projections of ``cond``, broadcast over axis 1 of ``x``.

    Args:
        eps: Epsilon passed to the inner ``LayerNormalization``.
    """

    def __init__(self, eps=1e-6, **kwargs):
        # Initialise the base Layer before assigning any sub-layer attribute.
        super().__init__(**kwargs)
        # Kept so get_config can reproduce a non-default epsilon on reload.
        self.eps = eps
        self.layernorm = LayerNormalization(epsilon=eps, center=False, scale=False)

    def build(self, input_shape):
        # B, num_patches, hidden_dim: both projections output hidden_dim values.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")
        self.b = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")

    def call(self, x, cond):
        """Normalise ``x`` and modulate it with the scale/shift derived from ``cond``."""
        scale = self.M(cond)
        shift = self.b(cond)
        x = self.layernorm(x)
        return x * (1 + tf.expand_dims(scale, 1)) + tf.expand_dims(shift, 1)

    def get_config(self):
        """Return the base layer config extended with ``eps``."""
        config = super().get_config()
        config.update({"eps": self.eps})
        return config
330
+
331
+ # @title Image Embedder, Unembedder
332
@register_keras_serializable()
class ImageEmbedder(tf.keras.Layer):
    """Turn a latent feature map into a sequence of patch tokens with position embeddings.

    Args:
        latent_size: Side length of the latent; with ``patch_size`` it fixes
            the number of learned positions, ``(latent_size // patch_size) ** 2``.
        patch_size: Kernel size and stride of the patchifying convolution.
        emb_dim: Width of each output token.
    """

    def __init__(self, latent_size, patch_size, emb_dim, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        self.pos_emb = Embedding(
            input_dim=(latent_size // patch_size) ** 2,
            output_dim=emb_dim,
            embeddings_initializer="glorot_uniform",
        )
        self.reshaper = Dense(emb_dim, kernel_initializer="glorot_uniform")
        self.conv_expansion = Conv2D(emb_dim, kernel_size=patch_size, strides=patch_size, padding="same")

    def call(self, x):
        """Project channels, patchify, flatten the grid and add position embeddings."""
        patches = self.conv_expansion(self.reshaper(x))
        dims = tf.shape(patches)
        # Collapse the two spatial axes into a single token axis.
        tokens = tf.reshape(patches, shape=[dims[0], dims[1] * dims[2], dims[3]])
        num_positions = (self.latent_size // self.patch_size) ** 2
        positions = tf.range(start=0, limit=num_positions, delta=1)
        return self.pos_emb(positions) + tokens

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        base = super().get_config()
        base.update(
            {
                "latent_size": self.latent_size,
                "patch_size": self.patch_size,
                "emb_dim": self.emb_dim,
            }
        )
        return base
359
+
360
@register_keras_serializable()
class ImageUnembedder(tf.keras.Layer):
    """Map a sequence of patch tokens back to a latent feature map.

    Args:
        latent_size: Side length used to compute the patch grid,
            ``latent_size // patch_size`` per side.
        patch_size: Block size passed to ``tf.nn.depth_to_space``.
        latent_dim: Number of channels of the output.
    """

    def __init__(self, latent_size, patch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        self.AdaLN = AdaptiveLayerNorm()
        self.reshape_to_latent = Dense(patch_size * patch_size * latent_dim, kernel_initializer="glorot_uniform")

    def call(self, x, cond):
        """Normalise with ``cond``, project each token to a patch, and reassemble the grid."""
        tokens = self.reshape_to_latent(self.AdaLN(x, cond))
        grid = self.latent_size // self.patch_size
        channels = self.latent_dim * (self.patch_size ** 2)
        patches = tf.reshape(tokens, shape=[tf.shape(tokens)[0], grid, grid, channels])
        # Spread each token's channels over a patch_size x patch_size spatial block.
        return tf.nn.depth_to_space(patches, block_size=self.patch_size)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        base = super().get_config()
        base.update(
            {
                "latent_size": self.latent_size,
                "patch_size": self.patch_size,
                "latent_dim": self.latent_dim,
            }
        )
        return base
388
+
389
+ # @title LEGACY Prompt and Timestep Embedder
390
@register_keras_serializable()
class ConditioningEmbedder(tf.keras.layers.Layer):
    """Combine a timestep and a tokenized prompt into one conditioning vector.

    The prompt tokens are embedded, prefixed with a learned CLS token, given
    learned positional encodings and passed through one ``TransformerBlock``.
    A fixed sinusoidal timestep embedding is added, and the CLS position is
    returned.

    Args:
        emb_dim: Width of all embeddings.
        T: Number of timesteps for which sinusoidal embeddings are precomputed.
        context_size: Number of prompt tokens (the sequence has one extra CLS slot).
        vocab_size: Number of rows in the prompt embedding table.
    """

    def __init__(self, emb_dim, T, context_size, vocab_size=100266, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.T = T
        self.context_size = context_size
        self.vocab_size = vocab_size

        # Fixed sinusoidal table: sin and cos values are interleaved per frequency.
        steps = tf.range(T, dtype=tf.float32)[:, tf.newaxis]
        freqs = tf.constant(10000 ** (-tf.range(0, emb_dim, 2, dtype=tf.float32) / emb_dim))
        angles = steps * freqs  # (T, emb_dim/2)
        table = tf.stack([tf.sin(angles), tf.cos(angles)], axis=-1)  # (T, emb_dim/2, 2)
        table = tf.reshape(table, [T, emb_dim])  # (T, emb_dim)
        self.t_embeddings = tf.constant(table, dtype=tf.float32)

        self.prompt_emb = self.add_weight(shape=(vocab_size, emb_dim), initializer='glorot_uniform', name='prompt_emb', trainable=True)
        self.CLS = self.add_weight(shape=(emb_dim,), initializer='glorot_uniform', name='CLS', trainable=True)
        self.prompt_pos_enc = self.add_weight(shape=(1, context_size+1, emb_dim), initializer='glorot_uniform', name='prompt_pos_enc', trainable=True)
        # The head count is fixed at 6 regardless of emb_dim.
        self.transformer = TransformerBlock(context_size+1, head_no=6, latent_dim=emb_dim)

    def call(self, x):
        """Return the conditioning vector for ``x = (t, prompt_tokens)``.

        ``t`` has a trailing singleton axis that is squeezed away before the
        lookup. The result has shape ``(batch, emb_dim)``.
        """
        t, prompt_tokens = x

        # Timestep path: look up the precomputed sinusoidal row for each sample.
        step_idx = tf.cast(tf.squeeze(t, axis=-1), tf.int32)  # (batch,)
        embedded_t = tf.gather(self.t_embeddings, step_idx)[:, tf.newaxis, :]  # (batch, 1, emb_dim)

        # Prompt path: embed tokens and prepend the CLS token.
        embedded_prompt = tf.nn.embedding_lookup(self.prompt_emb, prompt_tokens)  # (batch, seq_len, emb_dim)
        batch = tf.shape(embedded_prompt)[0]
        cls_tok = tf.tile(self.CLS[None, None, :], [batch, 1, 1])
        sequence = tf.concat([cls_tok, embedded_prompt], axis=1) + self.prompt_pos_enc
        sequence = self.transformer(sequence)  # (batch, seq_len+1, emb_dim)

        # The timestep embedding broadcasts along the token axis.
        sequence = sequence + embedded_t

        # Keep only the CLS position.
        return sequence[:, 0, :]  # (batch, emb_dim)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        base = super().get_config()
        base.update(
            {
                "emb_dim": self.emb_dim,
                "T": self.T,
                "context_size": self.context_size,
                "vocab_size": self.vocab_size,
            }
        )
        return base
445
+
446
+ # @title DiT Block
447
@register_keras_serializable()
class Gain(tf.keras.layers.Layer):
    """Scale ``x`` by a per-channel factor predicted from a conditioning vector.

    The constructor now forwards the standard Keras layer keyword arguments
    (``name``, ``dtype``, ``trainable``, ...) to the base class. The class is
    also registered as serializable, matching the other custom layers in
    this file. Calling ``Gain()`` with no arguments behaves as before.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One scale value per channel of the last axis of x.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform')

    def call(self, x, cond):
        """Return ``x`` multiplied by the projected ``cond``, broadcast over axis 1."""
        scale = self.M(cond)
        return x * tf.expand_dims(scale, 1)
458
+
459
@register_keras_serializable()
class DiTBlock(tf.keras.layers.Layer):
    """Conditioned transformer block with adaptive layer norm and gated residuals.

    Both sub-blocks (self-attention and a 4x GELU MLP) normalise their input
    with ``AdaptiveLayerNorm`` and scale their output with ``Gain`` before it
    is added back to ``x``; the conditioning vector drives both.

    Args:
        hidden_dim: Model width; the per-head key size is ``hidden_dim // heads``.
        heads: Number of attention heads.
        context_size: Stored and returned by ``get_config``; not otherwise read here.
    """

    def __init__(self, hidden_dim, heads, context_size, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = hidden_dim
        self.heads = heads
        self.context_size = context_size
        self.gain1 = Gain()
        self.gain2 = Gain()
        self.adaLN1 = AdaptiveLayerNorm()

        self.attn = MultiHeadAttention(
            num_heads=self.heads,
            key_dim=self.emb_dim // self.heads,
            output_shape=self.emb_dim,
        )
        self.adaLN2 = AdaptiveLayerNorm()
        self.mlp_up = Dense(self.emb_dim * 4, activation="gelu")
        self.mlp_down = Dense(self.emb_dim)

    def call(self, x, cond):
        """Return ``x`` after the conditioned attention and MLP residual updates."""
        attn_in = self.adaLN1(x, cond)
        x = x + self.gain1(self.attn(attn_in, attn_in, attn_in), cond)

        mlp_in = self.adaLN2(x, cond)
        mlp_out = self.mlp_down(self.mlp_up(mlp_in))
        x = x + self.gain2(mlp_out, cond)
        return x

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        base = super().get_config()
        base.update(
            {
                "hidden_dim": self.emb_dim,
                "heads": self.heads,
                "context_size": self.context_size,
            }
        )
        return base
491
+
492
# Load the saved models once, at import time. The paths are relative, so they
# are resolved against the process's current working directory.
encoder, decoder, diffuser = (
    tf.keras.models.load_model(model_file)
    for model_file in ("encoder.keras", "decoder.keras", "diffusion-med-coco.keras")
)
495
+
496
def inference(prompts):
    """Generate one image per prompt by running the full reverse diffusion loop.

    Starts from Gaussian noise of shape ``(N, *latent_shape)``, applies
    ``denoise_step`` for every timestep from ``T - 1`` down to 0 using the
    noise predicted by ``diffuser``, and decodes the final latent.

    Args:
        prompts: A list of prompt strings. A single string is treated as one
            prompt (previously it was iterated character by character).

    Returns:
        The output of ``decoder`` for the final latents.
    """
    if isinstance(prompts, str):
        prompts = [prompts]

    N = len(prompts)
    # Use the configured latent shape rather than a hard-coded (32, 32, 4).
    x_t = tf.random.normal(shape=(N, *latent_shape))
    texts = tf.convert_to_tensor([process_text(p) for p in prompts])

    for t in reversed(range(T)):
        t_batch = tf.convert_to_tensor([[t]] * N)
        eps_hat = diffuser([x_t, texts, t_batch])
        x_t = tf.convert_to_tensor(denoise_step(x_t.numpy(), eps_hat.numpy(), t, a, beta), dtype=tf.float32)

    x_0 = x_t.numpy()
    return decoder(x_0)
510
+
511
class EndpointHandler:
    """Callable request handler that turns text prompts into generated images."""

    def __init__(self, path="."):
        # ``path`` is accepted but not used: the models are loaded at module import.
        pass

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate images for the prompts in ``inputs["inputs"]``.

        Args:
            inputs: Request payload. ``inputs["inputs"]`` is either a list of
                prompt strings or a single prompt string. A single string is
                wrapped in a list so it yields one image, not one image per
                character.

        Returns:
            ``{"outputs": images}`` where ``images`` is the decoder output
            converted to nested Python lists.

        Raises:
            KeyError: If the payload has no ``"inputs"`` entry.
        """
        prompts = inputs["inputs"]
        if isinstance(prompts, str):
            prompts = [prompts]

        # Reuse the module-level sampling loop instead of duplicating it here.
        imgs = inference(prompts)
        return {"outputs": imgs.numpy().tolist()}