KublaiKhan1 committed (verified)
Commit 6fa0aca · 1 Parent(s): 0561cd8

Upload LinearAE/vqvae.py with huggingface_hub

Files changed (1)
  1. LinearAE/vqvae.py +523 -0
LinearAE/vqvae.py ADDED
@@ -0,0 +1,523 @@
import functools
import math
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from einops import rearrange

###########################
### Helper Modules
### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
###########################

def get_norm_layer(norm_type):
    """Normalization layer."""
    if norm_type == 'BN':
        raise NotImplementedError
    elif norm_type == 'LN':
        norm_fn = functools.partial(nn.LayerNorm)
    elif norm_type == 'GN':
        norm_fn = functools.partial(nn.GroupNorm)
    else:
        raise NotImplementedError
    return norm_fn

def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
    pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add,
                                     (1,) + window_shape + (1,),
                                     (1,) + strides + (1,), padding)
    pool_denom = jax.lax.reduce_window(
        jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,),
        (1,) + strides + (1,), padding)
    return pool_sum / pool_denom


def upsample(x, factor=2):
    n, h, w, c = x.shape
    x = jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')
    return x


def dsample(x):
    return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')

def squared_euclidean_distance(a: jnp.ndarray,
                               b: jnp.ndarray,
                               b2: jnp.ndarray = None) -> jnp.ndarray:
    """Computes the pairwise squared Euclidean distance.

    Args:
      a: float32: (n, d): An array of points.
      b: float32: (m, d): An array of points.
      b2: float32: (d, m): b square transpose.

    Returns:
      d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
        a[i] and b[j].
    """
    if b2 is None:
        b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
    a2 = jnp.sum(a**2, axis=1, keepdims=True)
    ab = jnp.matmul(a, b.T)
    d = a2 - 2 * ab + b2
    return d

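
# Hedged sanity check (illustrative only, not called anywhere in this file):
# squared_euclidean_distance relies on the expansion
# ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so it should agree with a direct
# broadcasted computation.
def _squared_euclidean_distance_check():
    a = jnp.array([[0.0, 0.0], [1.0, 2.0]])              # (2, d)
    b = jnp.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])  # (3, d)
    direct = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    assert jnp.allclose(squared_euclidean_distance(a, b), direct, atol=1e-5)
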
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
    """Calculates the entropy loss. Affinity is the similarity/distance matrix."""
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = jax.nn.softmax(flat_affinity, axis=-1)
    log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1)
    if loss_type == "softmax":
        target_probs = probs
    elif loss_type == "argmax":
        codes = jnp.argmax(flat_affinity, axis=-1)
        onehots = jax.nn.one_hot(
            codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype)
        onehots = probs - jax.lax.stop_gradient(probs - onehots)
        target_probs = onehots
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = jnp.mean(target_probs, axis=0)
    avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
    sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
    loss = sample_entropy - avg_entropy
    return loss

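
# Hedged illustration (not used by the model): the loss above is low when each
# row of `affinity` picks one code confidently (low sample entropy) while the
# batch as a whole spreads over many codes (high average entropy). An affinity
# where every row prefers a different code should therefore score lower
# (better) than one where every row prefers the same code.
def _entropy_loss_demo():
    spread = jnp.log(jax.nn.one_hot(jnp.array([0, 1, 2, 3]), 4) + 1e-3)
    collapsed = jnp.log(jax.nn.one_hot(jnp.array([0, 0, 0, 0]), 4) + 1e-3)
    return entropy_loss_fn(spread), entropy_loss_fn(collapsed)
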
def sg(x):
    return jax.lax.stop_gradient(x)


###########################
### Modules
###########################

class ResBlock(nn.Module):
    """Basic Residual Block."""
    filters: int
    norm_fn: Any
    activation_fn: Any

    @nn.compact
    def __call__(self, x):
        input_dim = x.shape[-1]
        residual = x
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)

        if input_dim != self.filters:
            # If the input width does not match the output width, project the
            # skip connection with a 1x1 conv so the shapes line up.
            residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(residual)
        return x + residual

class Fourier(nn.Module):
    """Random Fourier feature lift for a channels-last input (unused, unfinished sketch)."""
    # The input arrives with only a few channels; this projects it with a random
    # matrix and maps through cos/sin, giving 2 * features output channels.
    features: int = 256

    @nn.compact
    def __call__(self, x):
        weight = self.param('weight',
                            nn.initializers.normal(stddev=1.0),
                            (self.features, x.shape[-1]))
        f = 2 * math.pi * x @ weight.T
        return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)

class Encoder(nn.Module):
    """Linear patchify encoder: 8x8 patches followed by a Dense projection."""

    config: ml_collections.ConfigDict

    @nn.compact
    def __call__(self, x):
        print("init encoder")
        print("x shape", x.shape)
        # Fold each 8x8 spatial patch into the channel axis, then project.
        x = rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=8, b2=8)
        x = nn.Dense(4)(x)  # Project to 4 latent channels for now.
        print(x.shape)
        return x
        # Alternative idea from the original notes: a 1x1 conv lifting 3 -> 64
        # channels (256x256x64), then a pixel-shuffle style rearrange down to
        # the latent resolution.

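
# Hedged shape walk-through (assumes 256x256 RGB inputs, which matches the 8x8
# patch factor used above): (B, 256, 256, 3) -> rearrange -> (B, 32, 32, 192)
# -> Dense(4) -> (B, 32, 32, 4). The decoder then has to undo the 8x spatial
# reduction with its upsampling stages.
def _patchify_shape_demo():
    x = jnp.zeros((1, 256, 256, 3))
    patches = rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=8, b2=8)
    assert patches.shape == (1, 32, 32, 3 * 8 * 8)
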
class OriginalEncoder(nn.Module):
    """From [H,W,D] image to [H',W',D'] embedding. Using Conv layers."""
    config: ml_collections.ConfigDict

    def setup(self):
        self.filters = self.config.filters  # Base filter count.
        self.num_res_blocks = self.config.num_res_blocks
        self.channel_multipliers = self.config.channel_multipliers
        self.embedding_dim = self.config.embedding_dim
        self.norm_type = self.config.norm_type
        self.activation_fn = nn.swish

    @nn.compact
    def __call__(self, x):
        print("Initializing encoder.")
        norm_fn = get_norm_layer(norm_type=self.norm_type)
        block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
        print("Incoming encoder shape", x.shape)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        print('Encoder layer', x.shape)
        num_blocks = len(self.channel_multipliers)

        # Stable Diffusion's encoder follows the same pattern: a couple of
        # ResBlocks per resolution followed by a downsample, repeated three
        # times for an 8x reduction, then a final ResBlock stack before the
        # projection from 512 channels down to the 8- or 4-channel latent.

        for i in range(num_blocks):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            if i < num_blocks - 1:  # Downsample after every block except the last.
                print("doing downsample")
                x = dsample(x)
            print('Encoder layer', x.shape)

        # After downsampling, run the final ResBlock stack (the "mid block").

        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        print('Encoder layer final', x.shape)

        x = norm_fn()(x)
        x = self.activation_fn(x)
        last_dim = self.embedding_dim * 2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
        x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
        print("Final embeddings are size", x.shape)
        return x

class Decoder(nn.Module):
    """From [H',W',D'] embedding to [H,W,D] embedding. Using Conv layers."""

    config: ml_collections.ConfigDict

    def setup(self):
        self.filters = self.config.filters
        self.num_res_blocks = self.config.num_res_blocks
        self.channel_multipliers = self.config.channel_multipliers
        self.norm_type = self.config.norm_type
        self.image_channels = self.config.image_channels
        self.activation_fn = nn.swish

    @nn.compact
    def __call__(self, x):
        norm_fn = get_norm_layer(norm_type=self.norm_type)
        block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
        num_blocks = len(self.channel_multipliers)
        filters = self.filters * self.channel_multipliers[-1]
        print("Decoder incoming shape", x.shape)

        # Lift the latent back up to the widest channel count (e.g. 512).
        x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
        print("Decoder input", x.shape)

        # Mid block.
        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        print('Mid Block Decoder layer', x.shape)

        # Walk the channel multipliers in reverse; the first stages keep the
        # channel count, and every stage except the last upsamples 2x.
        for i in reversed(range(num_blocks)):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            if i > 0:
                x = upsample(x, 2)
                x = nn.Conv(filters, kernel_size=(3, 3))(x)
            print('Decoder layer', x.shape)
        x = norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
        return x

class VectorQuantizer(nn.Module):
    """Basic vector quantizer."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        codebook_size = self.config.codebook_size
        emb_dim = x.shape[-1]
        codebook = self.param(
            "codebook",
            jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
            (codebook_size, emb_dim))
        codebook = jnp.asarray(codebook)  # (codebook_size, emb_dim)
        distances = jnp.reshape(
            squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
            x.shape[:-1] + (codebook_size,))  # [..., codebook_size] distance matrix.
        encoding_indices = jnp.argmin(distances, axis=-1)
        encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
        quantized = self.quantize(encoding_onehot)
        result_dict = dict()
        if self.train:
            e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
            q_latent_loss = jnp.mean((quantized - sg(x))**2)
            entropy_loss = 0.0
            if self.config.entropy_loss_ratio != 0:
                entropy_loss = entropy_loss_fn(
                    -distances,
                    loss_type=self.config.entropy_loss_type,
                    temperature=self.config.entropy_temperature
                ) * self.config.entropy_loss_ratio
            e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
            q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
            entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
            loss = e_latent_loss + q_latent_loss + entropy_loss
            result_dict = dict(
                quantizer_loss=loss,
                e_latent_loss=e_latent_loss,
                q_latent_loss=q_latent_loss,
                entropy_loss=entropy_loss)
            quantized = x + jax.lax.stop_gradient(quantized - x)

        result_dict.update({
            "z_ids": encoding_indices,
        })
        return quantized, result_dict

    def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
        codebook = jnp.asarray(self.variables["params"]["codebook"])
        return jnp.dot(encoding_onehot, codebook)

    def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
        codebook = self.variables["params"]["codebook"]
        return jnp.take(codebook, ids, axis=0)

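
# Hedged illustration (not used by the model): the straight-through assignment
# quantized = x + stop_gradient(quantized - x) returns the codebook vector in
# the forward pass while its gradient w.r.t. x is the identity, so encoder
# gradients flow through the non-differentiable nearest-code lookup.
def _straight_through_demo():
    f = lambda v: jnp.sum(v + jax.lax.stop_gradient(jnp.round(v) - v))
    grads = jax.grad(f)(jnp.array([0.2, 1.7]))
    assert jnp.allclose(grads, jnp.ones(2))
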
class KLQuantizer(nn.Module):
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        emb_dim = x.shape[-1] // 2  # Use half as means, half as logvars.
        means = x[..., :emb_dim]
        logvars = x[..., emb_dim:]
        if not self.train:
            result_dict = dict()
            result_dict["std"] = jnp.exp(0.5 * logvars)
            return means, result_dict
        else:
            noise = jax.random.normal(self.make_rng("noise"), means.shape)
            stds = jnp.exp(0.5 * logvars)
            z = means + stds * noise
            # kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))

            # KL(N(means, stds^2) || N(0, I)): sum over latent dims, mean over batch.
            kl_loss = -0.5 * jnp.sum(1 + logvars - jnp.square(means) - jnp.exp(logvars),
                                     axis=tuple(range(1, means.ndim)))
            kl_loss = jnp.mean(kl_loss)

            result_dict = dict(quantizer_loss=kl_loss)
            result_dict["std"] = jnp.exp(0.5 * logvars)
            return z, result_dict

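
# Hedged worked check: for a diagonal Gaussian q = N(mu, sigma^2) the per-dim
# KL term used above is -0.5 * (1 + log sigma^2 - mu^2 - sigma^2), which is
# exactly zero when mu = 0 and sigma = 1 (i.e. logvars = 0).
def _kl_term_check():
    means = jnp.zeros((2, 4))
    logvars = jnp.zeros((2, 4))
    kl = -0.5 * jnp.sum(1 + logvars - jnp.square(means) - jnp.exp(logvars))
    assert jnp.allclose(kl, 0.0)
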
class AEQuantizer(nn.Module):  # Work in progress.
    """Plain autoencoder pass-through: no quantization and no regularizer."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        result_dict = dict()
        result_dict["std"] = 0.0
        return x, result_dict

def imq_kernel(X: jnp.ndarray, Y: jnp.ndarray, h_dim: int):
    batch_size = X.shape[0]

    norms_x = jnp.sum(X**2, axis=1, keepdims=True)  # batch_size x 1
    prods_x = jnp.dot(X, X.T)  # batch_size x batch_size
    dists_x = norms_x + norms_x.T - 2 * prods_x

    norms_y = jnp.sum(Y**2, axis=1, keepdims=True)  # batch_size x 1
    prods_y = jnp.dot(Y, Y.T)  # batch_size x batch_size
    dists_y = norms_y + norms_y.T - 2 * prods_y

    dot_prd = jnp.dot(X, Y.T)
    dists_c = norms_x + norms_y.T - 2 * dot_prd

    stats = 0
    for scale in [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]:
        C = 2 * h_dim * 1.0 * scale
        res1 = C / (C + dists_x)
        res1 += C / (C + dists_y)

        res1 = (1 - jnp.eye(batch_size)) * res1
        res1 = jnp.sum(res1) / (batch_size - 1)

        res2 = C / (C + dists_c)
        res2 = jnp.sum(res2) * 2.0 / batch_size
        stats += res1 - res2

    return stats

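
# Hedged illustration (not used by the model): the inverse-multiquadric MMD
# statistic above should be close to zero when X and Y come from the same
# distribution and noticeably larger when Y is shifted away from X.
def _imq_kernel_demo():
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    X = jax.random.normal(key_x, (64, 16))
    Y_same = jax.random.normal(key_y, (64, 16))
    Y_shifted = Y_same + 3.0
    return imq_kernel(X, Y_same, 16), imq_kernel(X, Y_shifted, 16)
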
class MMDQuantizer(nn.Module):  # Work in progress.
    """Pass-through that adds an MMD penalty between the flattened latents and scaled Gaussian noise."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        if not self.train:
            result_dict = dict()
            return x, result_dict
        else:
            print("mmd quantizer")
            batch_size, height, width, latent_channels = x.shape
            z_flat = x.reshape(batch_size, -1)
            print(z_flat.shape)
            z_fake_flat = jax.random.normal(self.make_rng("noise"), z_flat.shape) * self.config["MMD_weight"]
            print(z_fake_flat.shape)
            mmd_loss = imq_kernel(z_flat, z_fake_flat, z_flat.shape[1])
            print(mmd_loss.shape)
            print(mmd_loss)
            result_dict = dict(quantizer_loss=mmd_loss)
            return x, result_dict


class KLQuantizerTwo(nn.Module):
    """Denoising variant: KL-style penalty on the latents plus additive unit Gaussian noise."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        # Unlike KLQuantizer, the encoder output is used directly as the mean
        # and the standard deviation is either estimated or pinned to 1.
        # Open questions from the original notes: whether mean/std should be
        # taken over the batch axis, and whether to go from (b, hw, 8) to
        # (b, hw, 4) with mean and std over the two halves.

        if not self.train:
            result_dict = dict()
            result_dict["std"] = 1.0
            return x, result_dict
        else:
            stds = jnp.std(x, axis=[1, 2, 3])

            noise = jax.random.normal(self.make_rng("noise"), x.shape)

            logvars = 0.5 * jnp.log(stds)
            logvars = logvars.reshape(-1, 1, 1, 1)
            if True:  # Special-case KL where sigma is set to 1 manually.
                logvars = 0.0

            if False:  # DINO-SSL-style covariance regularizer (disabled experiment).
                x_2 = x.reshape(x.shape[0], -1, x.shape[-1])  # Flatten, channels last.
                x_2 = jnp.swapaxes(x_2, 0, 1)
                # Covariance over the batch for each spatial position.
                cov = jnp.swapaxes(x_2, 1, 2) @ x_2 / x.shape[0]
                # Not sure about this; the regular covariance is also an option.
                I_d = jnp.identity(x.shape[-1])
                R_eps = jnp.log(jnp.linalg.det(
                    jnp.expand_dims(I_d, axis=0) + x.shape[-1] / (.0001 ** 2) * cov))
                # Something here still depends on the channel dimension; needs working out.
                kl_loss = R_eps.mean()

            # Denoising version: treat x as the mean with unit variance.
            kl_loss = -0.5 * jnp.sum(1 + logvars - jnp.square(x) - jnp.exp(logvars),
                                     axis=tuple(range(1, x.ndim)))
            kl_loss = jnp.mean(kl_loss)

            result_dict = dict(quantizer_loss=kl_loss)
            result_dict["std"] = 1.0

            # For the proper KL-two behaviour, return mean + noise.
            return x + noise, result_dict


class FSQuantizer(nn.Module):
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        assert self.config['fsq_levels'] % 2 == 1, "FSQ levels must be odd."
        z = jnp.tanh(x)  # [-1, 1]
        z = z * (self.config['fsq_levels'] - 1) / 2  # [-fsq_levels/2, fsq_levels/2]
        zhat = jnp.round(z)  # e.g. [-2, -1, 0, 1, 2]
        quantized = z + jax.lax.stop_gradient(zhat - z)
        quantized = quantized / (self.config['fsq_levels'] // 2)  # [-1, 1], but quantized.
        result_dict = dict()

        # Diagnostics for codebook usage.
        zhat_scaled = zhat + self.config['fsq_levels'] // 2
        basis = jnp.concatenate(
            (jnp.array([1]),
             jnp.cumprod(jnp.array([self.config['fsq_levels']] * (x.shape[-1] - 1))))).astype(jnp.uint32)
        idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
        idx_flat = idx.reshape(-1)
        usage = jnp.bincount(idx_flat, length=self.config['fsq_levels']**x.shape[-1])

        result_dict.update({
            "z_ids": zhat,
            'usage': usage
        })
        return quantized, result_dict

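
# Hedged illustration: with fsq_levels = L and d latent channels, FSQ has an
# implied codebook of L**d entries; the `basis`/`idx` computation above maps
# each quantized vector to its index in mixed-radix (base-L) form.
def _fsq_index_demo(levels: int = 5):
    zhat = jnp.array([-2, 0, 1])          # per-channel codes in [-L//2, L//2]
    zhat_scaled = zhat + levels // 2      # shift to [0, L-1]
    basis = levels ** jnp.arange(zhat.shape[-1])
    idx = jnp.sum(zhat_scaled * basis)    # single integer in [0, L**d)
    return idx                            # here: 0 + 2*5 + 3*25 = 85
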
class VQVAE(nn.Module):
    """VQVAE model."""
    config: ml_collections.ConfigDict
    train: bool

    def setup(self):
        """VQVAE setup."""
        if self.config['quantizer_type'] == 'vq':
            self.quantizer = VectorQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'kl':
            self.quantizer = KLQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'fsq':
            self.quantizer = FSQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'ae':
            self.quantizer = AEQuantizer(config=self.config, train=self.train)
        elif self.config["quantizer_type"] == "kl_two":
            self.quantizer = KLQuantizerTwo(config=self.config, train=self.train)
        self.encoder = Encoder(config=self.config)
        self.decoder = Decoder(config=self.config)

    def encode(self, image):
        encoded_feature = self.encoder(image)
        quantized, result_dict = self.quantizer(encoded_feature)
        print("After quant", quantized.shape)
        return quantized, result_dict

    def decode(self, z_vectors):
        print("z_vectors shape", z_vectors.shape)
        reconstructed = self.decoder(z_vectors)
        return reconstructed

    def decode_from_indices(self, z_ids):
        z_vectors = self.quantizer.decode_ids(z_ids)
        reconstructed_image = self.decode(z_vectors)
        return reconstructed_image

    def encode_to_indices(self, image):
        encoded_feature = self.encoder(image)
        _, result_dict = self.quantizer(encoded_feature)
        ids = result_dict["z_ids"]
        return ids

    def __call__(self, input_dict):
        # The encoder is frozen here: stop_gradient over the encode output keeps
        # gradients from the reconstruction loss out of the encoder and quantizer.
        quantized, result_dict = jax.lax.stop_gradient(self.encode(input_dict))
        print("encode finished")
        result_dict["latents"] = quantized
        outputs = self.decoder(quantized)
        return outputs, result_dict
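
# Hedged usage sketch: one way to build and run the model above. The config
# fields and values here are assumptions inferred from the attributes this file
# reads (filters, channel_multipliers, quantizer_type, ...); the real training
# configs for this repo may differ.
if __name__ == "__main__":
    cfg = ml_collections.ConfigDict()
    cfg.filters = 64
    cfg.num_res_blocks = 2
    cfg.channel_multipliers = [1, 2, 2, 4]  # 3 decoder upsamples -> 8x, matching the 8x8 patchify.
    cfg.embedding_dim = 4
    cfg.norm_type = 'GN'
    cfg.image_channels = 3
    cfg.quantizer_type = 'ae'               # Plain pass-through quantizer.

    model = VQVAE(config=cfg, train=True)
    images = jnp.zeros((2, 256, 256, 3))
    rngs = {'params': jax.random.PRNGKey(0), 'noise': jax.random.PRNGKey(1)}
    variables = model.init(rngs, images)
    recon, aux = model.apply(variables, images, rngs={'noise': jax.random.PRNGKey(2)})
    print("reconstruction:", recon.shape, "aux keys:", list(aux.keys()))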