jcopo commited on
Commit
aa4dc29
·
verified ·
1 Parent(s): 0a02da1

Upload model at step 45000

Browse files
Files changed (1) hide show
  1. config.py +15 -2031
config.py CHANGED
@@ -1,2042 +1,26 @@
1
  """Model configuration for jcopo/mnist
2
 
3
- This file contains the model architecture definition.
4
  Training step: 45000
5
  Precision: float32
6
  """
7
 
8
- from triax.models.nn.condUNet import CondUNet2D
9
  import jax.numpy as jnp
10
  from flax import nnx
 
11
 
12
- # Model architecture
13
  model = CondUNet2D(
14
- blocks_down=[TimestepEmbedSequential( # Param: 320 (1.3 KB)
15
- layers=[Conv( # Param: 320 (1.3 KB)
16
- bias=Param( # 32 (128 B)
17
- value=Array(shape=(32,), dtype=dtype('float32'))
18
- ),
19
- bias_init=<function zeros at 0x7fb32b98c2c0>,
20
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
21
- dtype=float32,
22
- feature_group_count=1,
23
- in_features=1,
24
- input_dilation=1,
25
- kernel=Param( # 288 (1.2 KB)
26
- value=Array(shape=(3, 3, 1, 32), dtype=dtype('float32'))
27
- ),
28
- kernel_dilation=1,
29
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
30
- kernel_shape=(3, 3, 1, 64),
31
- kernel_size=(3, 3),
32
- mask=None,
33
- out_features=64,
34
- padding=(1, 1),
35
- param_dtype=float32,
36
- precision=None,
37
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
38
- strides=(1, 1),
39
- use_bias=True
40
- )]
41
- ), TimestepEmbedSequential( # Param: 26,880 (107.5 KB), RngState: 2 (12 B), Total: 26,882 (107.5 KB)
42
- layers=[ResnetBlock( # Param: 26,880 (107.5 KB), RngState: 2 (12 B), Total: 26,882 (107.5 KB)
43
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
44
- conv1=Conv( # Param: 9,248 (37.0 KB)
45
- bias=Param( # 32 (128 B)
46
- value=Array(shape=(32,), dtype=dtype('float32'))
47
- ),
48
- bias_init=<function zeros at 0x7fb32b98c2c0>,
49
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
50
- dtype=float32,
51
- feature_group_count=1,
52
- in_features=64,
53
- input_dilation=1,
54
- kernel=Param( # 9,216 (36.9 KB)
55
- value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
56
- ),
57
- kernel_dilation=1,
58
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
59
- kernel_shape=(3, 3, 64, 64),
60
- kernel_size=(3, 3),
61
- mask=None,
62
- out_features=64,
63
- padding=(1, 1),
64
- param_dtype=float32,
65
- precision=None,
66
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
67
- strides=(1, 1),
68
- use_bias=True
69
- ),
70
- conv2=Conv( # Param: 9,248 (37.0 KB)
71
- bias=Param( # 32 (128 B)
72
- value=Array(shape=(32,), dtype=dtype('float32'))
73
- ),
74
- bias_init=<function zeros at 0x7fb32b98c2c0>,
75
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
76
- dtype=float32,
77
- feature_group_count=1,
78
- in_features=64,
79
- input_dilation=1,
80
- kernel=Param( # 9,216 (36.9 KB)
81
- value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
82
- ),
83
- kernel_dilation=1,
84
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
85
- kernel_shape=(3, 3, 64, 64),
86
- kernel_size=(3, 3),
87
- mask=None,
88
- out_features=64,
89
- padding=(1, 1),
90
- param_dtype=float32,
91
- precision=None,
92
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
93
- strides=(1, 1),
94
- use_bias=True
95
- ),
96
- dropout=Dropout( # RngState: 2 (12 B)
97
- broadcast_dims=(),
98
- deterministic=False,
99
- rate=0.1,
100
- rng_collection='dropout',
101
- rngs=RngStream( # RngState: 2 (12 B)
102
- count=RngCount( # 1 (4 B)
103
- value=Array(0, dtype=uint32),
104
- tag='default'
105
- ),
106
- key=RngKey( # 1 (8 B)
107
- value=Array((), dtype=key<fry>) overlaying:
108
- [2585633080 2083471411],
109
- tag='default'
110
- ),
111
- tag='default'
112
- )
113
- ),
114
- embedding_dim=256,
115
- in_channels=64,
116
- norm1=GroupNorm( # Param: 64 (256 B)
117
- axis_index_groups=None,
118
- axis_name=None,
119
- bias=Param( # 32 (128 B)
120
- value=Array(shape=(32,), dtype=dtype('float32'))
121
- ),
122
- bias_init=<function zeros at 0x7fb32b98c2c0>,
123
- dtype=float32,
124
- epsilon=1e-06,
125
- feature_axis=-1,
126
- group_size=2,
127
- num_groups=32,
128
- param_dtype=float32,
129
- reduction_axes=None,
130
- scale=Param( # 32 (128 B)
131
- value=Array(shape=(32,), dtype=dtype('float32'))
132
- ),
133
- scale_init=<function ones at 0x7fb32b86e520>,
134
- use_bias=True,
135
- use_fast_variance=True,
136
- use_scale=True
137
- ),
138
- norm2=GroupNorm( # Param: 64 (256 B)
139
- axis_index_groups=None,
140
- axis_name=None,
141
- bias=Param( # 32 (128 B)
142
- value=Array(shape=(32,), dtype=dtype('float32'))
143
- ),
144
- bias_init=<function zeros at 0x7fb32b98c2c0>,
145
- dtype=float32,
146
- epsilon=1e-06,
147
- feature_axis=-1,
148
- group_size=2,
149
- num_groups=32,
150
- param_dtype=float32,
151
- reduction_axes=None,
152
- scale=Param( # 32 (128 B)
153
- value=Array(shape=(32,), dtype=dtype('float32'))
154
- ),
155
- scale_init=<function ones at 0x7fb32b86e520>,
156
- use_bias=True,
157
- use_fast_variance=True,
158
- use_scale=True
159
- ),
160
- out_channels=64,
161
- time_mlp=Linear( # Param: 8,256 (33.0 KB)
162
- bias=Param( # 64 (256 B)
163
- value=Array(shape=(64,), dtype=dtype('float32'))
164
- ),
165
- bias_init=<function zeros at 0x7fb32b98c2c0>,
166
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
167
- dtype=float32,
168
- in_features=256,
169
- kernel=Param( # 8,192 (32.8 KB)
170
- value=Array(shape=(128, 64), dtype=dtype('float32'))
171
- ),
172
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
173
- out_features=128,
174
- param_dtype=float32,
175
- precision=None,
176
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
177
- use_bias=True
178
  )
179
- )]
180
- ), TimestepEmbedSequential( # Param: 9,248 (37.0 KB)
181
- layers=[Downsample( # Param: 9,248 (37.0 KB)
182
- conv=Conv( # Param: 9,248 (37.0 KB)
183
- bias=Param( # 32 (128 B)
184
- value=Array(shape=(32,), dtype=dtype('float32'))
185
- ),
186
- bias_init=<function zeros at 0x7fb32b98c2c0>,
187
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
188
- dtype=float32,
189
- feature_group_count=1,
190
- in_features=64,
191
- input_dilation=1,
192
- kernel=Param( # 9,216 (36.9 KB)
193
- value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
194
- ),
195
- kernel_dilation=1,
196
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
197
- kernel_shape=(3, 3, 64, 64),
198
- kernel_size=(3, 3),
199
- mask=None,
200
- out_features=64,
201
- padding=(0, 0),
202
- param_dtype=float32,
203
- precision=None,
204
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
205
- strides=(2, 2),
206
- use_bias=True
207
- ),
208
- method='conv'
209
- )]
210
- ), TimestepEmbedSequential( # Param: 91,008 (364.0 KB), RngState: 2 (12 B), Total: 91,010 (364.0 KB)
211
- layers=[ResnetBlock( # Param: 74,240 (297.0 KB), RngState: 2 (12 B), Total: 74,242 (297.0 KB)
212
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
213
- conv1=Conv( # Param: 18,496 (74.0 KB)
214
- bias=Param( # 64 (256 B)
215
- value=Array(shape=(64,), dtype=dtype('float32'))
216
- ),
217
- bias_init=<function zeros at 0x7fb32b98c2c0>,
218
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
219
- dtype=float32,
220
- feature_group_count=1,
221
- in_features=64,
222
- input_dilation=1,
223
- kernel=Param( # 18,432 (73.7 KB)
224
- value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
225
- ),
226
- kernel_dilation=1,
227
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
228
- kernel_shape=(3, 3, 64, 128),
229
- kernel_size=(3, 3),
230
- mask=None,
231
- out_features=128,
232
- padding=(1, 1),
233
- param_dtype=float32,
234
- precision=None,
235
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
236
- strides=(1, 1),
237
- use_bias=True
238
- ),
239
- conv2=Conv( # Param: 36,928 (147.7 KB)
240
- bias=Param( # 64 (256 B)
241
- value=Array(shape=(64,), dtype=dtype('float32'))
242
- ),
243
- bias_init=<function zeros at 0x7fb32b98c2c0>,
244
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
245
- dtype=float32,
246
- feature_group_count=1,
247
- in_features=128,
248
- input_dilation=1,
249
- kernel=Param( # 36,864 (147.5 KB)
250
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
251
- ),
252
- kernel_dilation=1,
253
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
254
- kernel_shape=(3, 3, 128, 128),
255
- kernel_size=(3, 3),
256
- mask=None,
257
- out_features=128,
258
- padding=(1, 1),
259
- param_dtype=float32,
260
- precision=None,
261
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
262
- strides=(1, 1),
263
- use_bias=True
264
- ),
265
- dropout=Dropout( # RngState: 2 (12 B)
266
- broadcast_dims=(),
267
- deterministic=False,
268
- rate=0.1,
269
- rng_collection='dropout',
270
- rngs=RngStream( # RngState: 2 (12 B)
271
- count=RngCount( # 1 (4 B)
272
- value=Array(0, dtype=uint32),
273
- tag='default'
274
- ),
275
- key=RngKey( # 1 (8 B)
276
- value=Array((), dtype=key<fry>) overlaying:
277
- [2656139193 2766658851],
278
- tag='default'
279
- ),
280
- tag='default'
281
- )
282
- ),
283
- embedding_dim=256,
284
- in_channels=64,
285
- nin_shortcut=Conv( # Param: 2,112 (8.4 KB)
286
- bias=Param( # 64 (256 B)
287
- value=Array(shape=(64,), dtype=dtype('float32'))
288
- ),
289
- bias_init=<function zeros at 0x7fb32b98c2c0>,
290
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
291
- dtype=float32,
292
- feature_group_count=1,
293
- in_features=64,
294
- input_dilation=1,
295
- kernel=Param( # 2,048 (8.2 KB)
296
- value=Array(shape=(1, 1, 32, 64), dtype=dtype('float32'))
297
- ),
298
- kernel_dilation=1,
299
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
300
- kernel_shape=(1, 1, 64, 128),
301
- kernel_size=(1, 1),
302
- mask=None,
303
- out_features=128,
304
- padding=(0, 0),
305
- param_dtype=float32,
306
- precision=None,
307
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
308
- strides=(1, 1),
309
- use_bias=True
310
- ),
311
- norm1=GroupNorm( # Param: 64 (256 B)
312
- axis_index_groups=None,
313
- axis_name=None,
314
- bias=Param( # 32 (128 B)
315
- value=Array(shape=(32,), dtype=dtype('float32'))
316
- ),
317
- bias_init=<function zeros at 0x7fb32b98c2c0>,
318
- dtype=float32,
319
- epsilon=1e-06,
320
- feature_axis=-1,
321
- group_size=2,
322
- num_groups=32,
323
- param_dtype=float32,
324
- reduction_axes=None,
325
- scale=Param( # 32 (128 B)
326
- value=Array(shape=(32,), dtype=dtype('float32'))
327
- ),
328
- scale_init=<function ones at 0x7fb32b86e520>,
329
- use_bias=True,
330
- use_fast_variance=True,
331
- use_scale=True
332
- ),
333
- norm2=GroupNorm( # Param: 128 (512 B)
334
- axis_index_groups=None,
335
- axis_name=None,
336
- bias=Param( # 64 (256 B)
337
- value=Array(shape=(64,), dtype=dtype('float32'))
338
- ),
339
- bias_init=<function zeros at 0x7fb32b98c2c0>,
340
- dtype=float32,
341
- epsilon=1e-06,
342
- feature_axis=-1,
343
- group_size=4,
344
- num_groups=32,
345
- param_dtype=float32,
346
- reduction_axes=None,
347
- scale=Param( # 64 (256 B)
348
- value=Array(shape=(64,), dtype=dtype('float32'))
349
- ),
350
- scale_init=<function ones at 0x7fb32b86e520>,
351
- use_bias=True,
352
- use_fast_variance=True,
353
- use_scale=True
354
- ),
355
- out_channels=128,
356
- time_mlp=Linear( # Param: 16,512 (66.0 KB)
357
- bias=Param( # 128 (512 B)
358
- value=Array(shape=(128,), dtype=dtype('float32'))
359
- ),
360
- bias_init=<function zeros at 0x7fb32b98c2c0>,
361
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
362
- dtype=float32,
363
- in_features=256,
364
- kernel=Param( # 16,384 (65.5 KB)
365
- value=Array(shape=(128, 128), dtype=dtype('float32'))
366
- ),
367
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
368
- out_features=256,
369
- param_dtype=float32,
370
- precision=None,
371
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
372
- use_bias=True
373
- )
374
- ), AttnBlock( # Param: 16,768 (67.1 KB)
375
- dtype=float32,
376
- head_dim=32,
377
- k=Conv( # Param: 4,160 (16.6 KB)
378
- bias=Param( # 64 (256 B)
379
- value=Array(shape=(64,), dtype=dtype('float32'))
380
- ),
381
- bias_init=<function zeros at 0x7fb32b98c2c0>,
382
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
383
- dtype=float32,
384
- feature_group_count=1,
385
- in_features=128,
386
- input_dilation=1,
387
- kernel=Param( # 4,096 (16.4 KB)
388
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
389
- ),
390
- kernel_dilation=1,
391
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
392
- kernel_shape=(1, 1, 128, 128),
393
- kernel_size=(1, 1),
394
- mask=None,
395
- out_features=128,
396
- padding='SAME',
397
- param_dtype=float32,
398
- precision=None,
399
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
400
- strides=1,
401
- use_bias=True
402
- ),
403
- norm=GroupNorm( # Param: 128 (512 B)
404
- axis_index_groups=None,
405
- axis_name=None,
406
- bias=Param( # 64 (256 B)
407
- value=Array(shape=(64,), dtype=dtype('float32'))
408
- ),
409
- bias_init=<function zeros at 0x7fb32b98c2c0>,
410
- dtype=float32,
411
- epsilon=1e-06,
412
- feature_axis=-1,
413
- group_size=4,
414
- num_groups=32,
415
- param_dtype=float32,
416
- reduction_axes=None,
417
- scale=Param( # 64 (256 B)
418
- value=Array(shape=(64,), dtype=dtype('float32'))
419
- ),
420
- scale_init=<function ones at 0x7fb32b86e520>,
421
- use_bias=True,
422
- use_fast_variance=True,
423
- use_scale=True
424
- ),
425
- num_heads=4,
426
- proj_out=Conv( # Param: 4,160 (16.6 KB)
427
- bias=Param( # 64 (256 B)
428
- value=Array(shape=(64,), dtype=dtype('float32'))
429
- ),
430
- bias_init=<function zeros at 0x7fb32b98c2c0>,
431
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
432
- dtype=float32,
433
- feature_group_count=1,
434
- in_features=128,
435
- input_dilation=1,
436
- kernel=Param( # 4,096 (16.4 KB)
437
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
438
- ),
439
- kernel_dilation=1,
440
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
441
- kernel_shape=(1, 1, 128, 128),
442
- kernel_size=(1, 1),
443
- mask=None,
444
- out_features=128,
445
- padding='SAME',
446
- param_dtype=float32,
447
- precision=None,
448
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
449
- strides=1,
450
- use_bias=True
451
- ),
452
- q=Conv( # Param: 4,160 (16.6 KB)
453
- bias=Param( # 64 (256 B)
454
- value=Array(shape=(64,), dtype=dtype('float32'))
455
- ),
456
- bias_init=<function zeros at 0x7fb32b98c2c0>,
457
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
458
- dtype=float32,
459
- feature_group_count=1,
460
- in_features=128,
461
- input_dilation=1,
462
- kernel=Param( # 4,096 (16.4 KB)
463
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
464
- ),
465
- kernel_dilation=1,
466
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
467
- kernel_shape=(1, 1, 128, 128),
468
- kernel_size=(1, 1),
469
- mask=None,
470
- out_features=128,
471
- padding='SAME',
472
- param_dtype=float32,
473
- precision=None,
474
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
475
- strides=1,
476
- use_bias=True
477
- ),
478
- v=Conv( # Param: 4,160 (16.6 KB)
479
- bias=Param( # 64 (256 B)
480
- value=Array(shape=(64,), dtype=dtype('float32'))
481
- ),
482
- bias_init=<function zeros at 0x7fb32b98c2c0>,
483
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
484
- dtype=float32,
485
- feature_group_count=1,
486
- in_features=128,
487
- input_dilation=1,
488
- kernel=Param( # 4,096 (16.4 KB)
489
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
490
- ),
491
- kernel_dilation=1,
492
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
493
- kernel_shape=(1, 1, 128, 128),
494
- kernel_size=(1, 1),
495
- mask=None,
496
- out_features=128,
497
- padding='SAME',
498
- param_dtype=float32,
499
- precision=None,
500
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
501
- strides=1,
502
- use_bias=True
503
- )
504
- )]
505
- ), TimestepEmbedSequential( # Param: 36,928 (147.7 KB)
506
- layers=[Downsample( # Param: 36,928 (147.7 KB)
507
- conv=Conv( # Param: 36,928 (147.7 KB)
508
- bias=Param( # 64 (256 B)
509
- value=Array(shape=(64,), dtype=dtype('float32'))
510
- ),
511
- bias_init=<function zeros at 0x7fb32b98c2c0>,
512
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
513
- dtype=float32,
514
- feature_group_count=1,
515
- in_features=128,
516
- input_dilation=1,
517
- kernel=Param( # 36,864 (147.5 KB)
518
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
519
- ),
520
- kernel_dilation=1,
521
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
522
- kernel_shape=(3, 3, 128, 128),
523
- kernel_size=(3, 3),
524
- mask=None,
525
- out_features=128,
526
- padding=(0, 0),
527
- param_dtype=float32,
528
- precision=None,
529
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
530
- strides=(2, 2),
531
- use_bias=True
532
- ),
533
- method='conv'
534
- )]
535
- ), TimestepEmbedSequential( # Param: 90,624 (362.5 KB), RngState: 2 (12 B), Total: 90,626 (362.5 KB)
536
- layers=[ResnetBlock( # Param: 90,624 (362.5 KB), RngState: 2 (12 B), Total: 90,626 (362.5 KB)
537
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
538
- conv1=Conv( # Param: 36,928 (147.7 KB)
539
- bias=Param( # 64 (256 B)
540
- value=Array(shape=(64,), dtype=dtype('float32'))
541
- ),
542
- bias_init=<function zeros at 0x7fb32b98c2c0>,
543
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
544
- dtype=float32,
545
- feature_group_count=1,
546
- in_features=128,
547
- input_dilation=1,
548
- kernel=Param( # 36,864 (147.5 KB)
549
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
550
- ),
551
- kernel_dilation=1,
552
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
553
- kernel_shape=(3, 3, 128, 128),
554
- kernel_size=(3, 3),
555
- mask=None,
556
- out_features=128,
557
- padding=(1, 1),
558
- param_dtype=float32,
559
- precision=None,
560
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
561
- strides=(1, 1),
562
- use_bias=True
563
- ),
564
- conv2=Conv( # Param: 36,928 (147.7 KB)
565
- bias=Param( # 64 (256 B)
566
- value=Array(shape=(64,), dtype=dtype('float32'))
567
- ),
568
- bias_init=<function zeros at 0x7fb32b98c2c0>,
569
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
570
- dtype=float32,
571
- feature_group_count=1,
572
- in_features=128,
573
- input_dilation=1,
574
- kernel=Param( # 36,864 (147.5 KB)
575
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
576
- ),
577
- kernel_dilation=1,
578
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
579
- kernel_shape=(3, 3, 128, 128),
580
- kernel_size=(3, 3),
581
- mask=None,
582
- out_features=128,
583
- padding=(1, 1),
584
- param_dtype=float32,
585
- precision=None,
586
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
587
- strides=(1, 1),
588
- use_bias=True
589
- ),
590
- dropout=Dropout( # RngState: 2 (12 B)
591
- broadcast_dims=(),
592
- deterministic=False,
593
- rate=0.1,
594
- rng_collection='dropout',
595
- rngs=RngStream( # RngState: 2 (12 B)
596
- count=RngCount( # 1 (4 B)
597
- value=Array(0, dtype=uint32),
598
- tag='default'
599
- ),
600
- key=RngKey( # 1 (8 B)
601
- value=Array((), dtype=key<fry>) overlaying:
602
- [ 692128146 1043829861],
603
- tag='default'
604
- ),
605
- tag='default'
606
- )
607
- ),
608
- embedding_dim=256,
609
- in_channels=128,
610
- norm1=GroupNorm( # Param: 128 (512 B)
611
- axis_index_groups=None,
612
- axis_name=None,
613
- bias=Param( # 64 (256 B)
614
- value=Array(shape=(64,), dtype=dtype('float32'))
615
- ),
616
- bias_init=<function zeros at 0x7fb32b98c2c0>,
617
- dtype=float32,
618
- epsilon=1e-06,
619
- feature_axis=-1,
620
- group_size=4,
621
- num_groups=32,
622
- param_dtype=float32,
623
- reduction_axes=None,
624
- scale=Param( # 64 (256 B)
625
- value=Array(shape=(64,), dtype=dtype('float32'))
626
- ),
627
- scale_init=<function ones at 0x7fb32b86e520>,
628
- use_bias=True,
629
- use_fast_variance=True,
630
- use_scale=True
631
- ),
632
- norm2=GroupNorm( # Param: 128 (512 B)
633
- axis_index_groups=None,
634
- axis_name=None,
635
- bias=Param( # 64 (256 B)
636
- value=Array(shape=(64,), dtype=dtype('float32'))
637
- ),
638
- bias_init=<function zeros at 0x7fb32b98c2c0>,
639
- dtype=float32,
640
- epsilon=1e-06,
641
- feature_axis=-1,
642
- group_size=4,
643
- num_groups=32,
644
- param_dtype=float32,
645
- reduction_axes=None,
646
- scale=Param( # 64 (256 B)
647
- value=Array(shape=(64,), dtype=dtype('float32'))
648
- ),
649
- scale_init=<function ones at 0x7fb32b86e520>,
650
- use_bias=True,
651
- use_fast_variance=True,
652
- use_scale=True
653
- ),
654
- out_channels=128,
655
- time_mlp=Linear( # Param: 16,512 (66.0 KB)
656
- bias=Param( # 128 (512 B)
657
- value=Array(shape=(128,), dtype=dtype('float32'))
658
- ),
659
- bias_init=<function zeros at 0x7fb32b98c2c0>,
660
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
661
- dtype=float32,
662
- in_features=256,
663
- kernel=Param( # 16,384 (65.5 KB)
664
- value=Array(shape=(128, 128), dtype=dtype('float32'))
665
- ),
666
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
667
- out_features=256,
668
- param_dtype=float32,
669
- precision=None,
670
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
671
- use_bias=True
672
- )
673
- )]
674
- )],
675
- blocks_up=[TimestepEmbedSequential( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
676
- layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
677
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
678
- conv1=Conv( # Param: 73,792 (295.2 KB)
679
- bias=Param( # 64 (256 B)
680
- value=Array(shape=(64,), dtype=dtype('float32'))
681
- ),
682
- bias_init=<function zeros at 0x7fb32b98c2c0>,
683
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
684
- dtype=float32,
685
- feature_group_count=1,
686
- in_features=256,
687
- input_dilation=1,
688
- kernel=Param( # 73,728 (294.9 KB)
689
- value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
690
- ),
691
- kernel_dilation=1,
692
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
693
- kernel_shape=(3, 3, 256, 128),
694
- kernel_size=(3, 3),
695
- mask=None,
696
- out_features=128,
697
- padding=(1, 1),
698
- param_dtype=float32,
699
- precision=None,
700
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
701
- strides=(1, 1),
702
- use_bias=True
703
- ),
704
- conv2=Conv( # Param: 36,928 (147.7 KB)
705
- bias=Param( # 64 (256 B)
706
- value=Array(shape=(64,), dtype=dtype('float32'))
707
- ),
708
- bias_init=<function zeros at 0x7fb32b98c2c0>,
709
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
710
- dtype=float32,
711
- feature_group_count=1,
712
- in_features=128,
713
- input_dilation=1,
714
- kernel=Param( # 36,864 (147.5 KB)
715
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
716
- ),
717
- kernel_dilation=1,
718
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
719
- kernel_shape=(3, 3, 128, 128),
720
- kernel_size=(3, 3),
721
- mask=None,
722
- out_features=128,
723
- padding=(1, 1),
724
- param_dtype=float32,
725
- precision=None,
726
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
727
- strides=(1, 1),
728
- use_bias=True
729
- ),
730
- dropout=Dropout( # RngState: 2 (12 B)
731
- broadcast_dims=(),
732
- deterministic=False,
733
- rate=0.1,
734
- rng_collection='dropout',
735
- rngs=RngStream( # RngState: 2 (12 B)
736
- count=RngCount( # 1 (4 B)
737
- value=Array(0, dtype=uint32),
738
- tag='default'
739
- ),
740
- key=RngKey( # 1 (8 B)
741
- value=Array((), dtype=key<fry>) overlaying:
742
- [2853902436 2217684095],
743
- tag='default'
744
- ),
745
- tag='default'
746
- )
747
- ),
748
- embedding_dim=256,
749
- in_channels=256,
750
- nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
751
- bias=Param( # 64 (256 B)
752
- value=Array(shape=(64,), dtype=dtype('float32'))
753
- ),
754
- bias_init=<function zeros at 0x7fb32b98c2c0>,
755
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
756
- dtype=float32,
757
- feature_group_count=1,
758
- in_features=256,
759
- input_dilation=1,
760
- kernel=Param( # 8,192 (32.8 KB)
761
- value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
762
- ),
763
- kernel_dilation=1,
764
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
765
- kernel_shape=(1, 1, 256, 128),
766
- kernel_size=(1, 1),
767
- mask=None,
768
- out_features=128,
769
- padding=(0, 0),
770
- param_dtype=float32,
771
- precision=None,
772
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
773
- strides=(1, 1),
774
- use_bias=True
775
- ),
776
- norm1=GroupNorm( # Param: 256 (1.0 KB)
777
- axis_index_groups=None,
778
- axis_name=None,
779
- bias=Param( # 128 (512 B)
780
- value=Array(shape=(128,), dtype=dtype('float32'))
781
- ),
782
- bias_init=<function zeros at 0x7fb32b98c2c0>,
783
- dtype=float32,
784
- epsilon=1e-06,
785
- feature_axis=-1,
786
- group_size=8,
787
- num_groups=32,
788
- param_dtype=float32,
789
- reduction_axes=None,
790
- scale=Param( # 128 (512 B)
791
- value=Array(shape=(128,), dtype=dtype('float32'))
792
- ),
793
- scale_init=<function ones at 0x7fb32b86e520>,
794
- use_bias=True,
795
- use_fast_variance=True,
796
- use_scale=True
797
- ),
798
- norm2=GroupNorm( # Param: 128 (512 B)
799
- axis_index_groups=None,
800
- axis_name=None,
801
- bias=Param( # 64 (256 B)
802
- value=Array(shape=(64,), dtype=dtype('float32'))
803
- ),
804
- bias_init=<function zeros at 0x7fb32b98c2c0>,
805
- dtype=float32,
806
- epsilon=1e-06,
807
- feature_axis=-1,
808
- group_size=4,
809
- num_groups=32,
810
- param_dtype=float32,
811
- reduction_axes=None,
812
- scale=Param( # 64 (256 B)
813
- value=Array(shape=(64,), dtype=dtype('float32'))
814
- ),
815
- scale_init=<function ones at 0x7fb32b86e520>,
816
- use_bias=True,
817
- use_fast_variance=True,
818
- use_scale=True
819
- ),
820
- out_channels=128,
821
- time_mlp=Linear( # Param: 16,512 (66.0 KB)
822
- bias=Param( # 128 (512 B)
823
- value=Array(shape=(128,), dtype=dtype('float32'))
824
- ),
825
- bias_init=<function zeros at 0x7fb32b98c2c0>,
826
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
827
- dtype=float32,
828
- in_features=256,
829
- kernel=Param( # 16,384 (65.5 KB)
830
- value=Array(shape=(128, 128), dtype=dtype('float32'))
831
- ),
832
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
833
- out_features=256,
834
- param_dtype=float32,
835
- precision=None,
836
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
837
- use_bias=True
838
- )
839
- )]
840
- ), TimestepEmbedSequential( # Param: 320,512 (1.3 MB), RngState: 2 (12 B), Total: 320,514 (1.3 MB)
841
- layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
842
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
843
- conv1=Conv( # Param: 73,792 (295.2 KB)
844
- bias=Param( # 64 (256 B)
845
- value=Array(shape=(64,), dtype=dtype('float32'))
846
- ),
847
- bias_init=<function zeros at 0x7fb32b98c2c0>,
848
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
849
- dtype=float32,
850
- feature_group_count=1,
851
- in_features=256,
852
- input_dilation=1,
853
- kernel=Param( # 73,728 (294.9 KB)
854
- value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
855
- ),
856
- kernel_dilation=1,
857
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
858
- kernel_shape=(3, 3, 256, 128),
859
- kernel_size=(3, 3),
860
- mask=None,
861
- out_features=128,
862
- padding=(1, 1),
863
- param_dtype=float32,
864
- precision=None,
865
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
866
- strides=(1, 1),
867
- use_bias=True
868
- ),
869
- conv2=Conv( # Param: 36,928 (147.7 KB)
870
- bias=Param( # 64 (256 B)
871
- value=Array(shape=(64,), dtype=dtype('float32'))
872
- ),
873
- bias_init=<function zeros at 0x7fb32b98c2c0>,
874
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
875
- dtype=float32,
876
- feature_group_count=1,
877
- in_features=128,
878
- input_dilation=1,
879
- kernel=Param( # 36,864 (147.5 KB)
880
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
881
- ),
882
- kernel_dilation=1,
883
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
884
- kernel_shape=(3, 3, 128, 128),
885
- kernel_size=(3, 3),
886
- mask=None,
887
- out_features=128,
888
- padding=(1, 1),
889
- param_dtype=float32,
890
- precision=None,
891
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
892
- strides=(1, 1),
893
- use_bias=True
894
- ),
895
- dropout=Dropout( # RngState: 2 (12 B)
896
- broadcast_dims=(),
897
- deterministic=False,
898
- rate=0.1,
899
- rng_collection='dropout',
900
- rngs=RngStream( # RngState: 2 (12 B)
901
- count=RngCount( # 1 (4 B)
902
- value=Array(0, dtype=uint32),
903
- tag='default'
904
- ),
905
- key=RngKey( # 1 (8 B)
906
- value=Array((), dtype=key<fry>) overlaying:
907
- [2785098898 841100811],
908
- tag='default'
909
- ),
910
- tag='default'
911
- )
912
- ),
913
- embedding_dim=256,
914
- in_channels=256,
915
- nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
916
- bias=Param( # 64 (256 B)
917
- value=Array(shape=(64,), dtype=dtype('float32'))
918
- ),
919
- bias_init=<function zeros at 0x7fb32b98c2c0>,
920
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
921
- dtype=float32,
922
- feature_group_count=1,
923
- in_features=256,
924
- input_dilation=1,
925
- kernel=Param( # 8,192 (32.8 KB)
926
- value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
927
- ),
928
- kernel_dilation=1,
929
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
930
- kernel_shape=(1, 1, 256, 128),
931
- kernel_size=(1, 1),
932
- mask=None,
933
- out_features=128,
934
- padding=(0, 0),
935
- param_dtype=float32,
936
- precision=None,
937
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
938
- strides=(1, 1),
939
- use_bias=True
940
- ),
941
- norm1=GroupNorm( # Param: 256 (1.0 KB)
942
- axis_index_groups=None,
943
- axis_name=None,
944
- bias=Param( # 128 (512 B)
945
- value=Array(shape=(128,), dtype=dtype('float32'))
946
- ),
947
- bias_init=<function zeros at 0x7fb32b98c2c0>,
948
- dtype=float32,
949
- epsilon=1e-06,
950
- feature_axis=-1,
951
- group_size=8,
952
- num_groups=32,
953
- param_dtype=float32,
954
- reduction_axes=None,
955
- scale=Param( # 128 (512 B)
956
- value=Array(shape=(128,), dtype=dtype('float32'))
957
- ),
958
- scale_init=<function ones at 0x7fb32b86e520>,
959
- use_bias=True,
960
- use_fast_variance=True,
961
- use_scale=True
962
- ),
963
- norm2=GroupNorm( # Param: 128 (512 B)
964
- axis_index_groups=None,
965
- axis_name=None,
966
- bias=Param( # 64 (256 B)
967
- value=Array(shape=(64,), dtype=dtype('float32'))
968
- ),
969
- bias_init=<function zeros at 0x7fb32b98c2c0>,
970
- dtype=float32,
971
- epsilon=1e-06,
972
- feature_axis=-1,
973
- group_size=4,
974
- num_groups=32,
975
- param_dtype=float32,
976
- reduction_axes=None,
977
- scale=Param( # 64 (256 B)
978
- value=Array(shape=(64,), dtype=dtype('float32'))
979
- ),
980
- scale_init=<function ones at 0x7fb32b86e520>,
981
- use_bias=True,
982
- use_fast_variance=True,
983
- use_scale=True
984
- ),
985
- out_channels=128,
986
- time_mlp=Linear( # Param: 16,512 (66.0 KB)
987
- bias=Param( # 128 (512 B)
988
- value=Array(shape=(128,), dtype=dtype('float32'))
989
- ),
990
- bias_init=<function zeros at 0x7fb32b98c2c0>,
991
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
992
- dtype=float32,
993
- in_features=256,
994
- kernel=Param( # 16,384 (65.5 KB)
995
- value=Array(shape=(128, 128), dtype=dtype('float32'))
996
- ),
997
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
998
- out_features=256,
999
- param_dtype=float32,
1000
- precision=None,
1001
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1002
- use_bias=True
1003
- )
1004
- ), Upsample( # Param: 184,640 (738.6 KB)
1005
- conv_pixel_shuffle=Conv( # Param: 147,712 (590.8 KB)
1006
- bias=Param( # 256 (1.0 KB)
1007
- value=Array(shape=(256,), dtype=dtype('float32'))
1008
- ),
1009
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1010
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1011
- dtype=float32,
1012
- feature_group_count=1,
1013
- in_features=128,
1014
- input_dilation=1,
1015
- kernel=Param( # 147,456 (589.8 KB)
1016
- value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32'))
1017
- ),
1018
- kernel_dilation=1,
1019
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1020
- kernel_shape=(3, 3, 128, 512),
1021
- kernel_size=(3, 3),
1022
- mask=None,
1023
- out_features=512,
1024
- padding=(1, 1),
1025
- param_dtype=float32,
1026
- precision=None,
1027
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1028
- strides=(1, 1),
1029
- use_bias=True
1030
- ),
1031
- conv_resize=Conv( # Param: 36,928 (147.7 KB)
1032
- bias=Param( # 64 (256 B)
1033
- value=Array(shape=(64,), dtype=dtype('float32'))
1034
- ),
1035
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1036
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1037
- dtype=float32,
1038
- feature_group_count=1,
1039
- in_features=128,
1040
- input_dilation=1,
1041
- kernel=Param( # 36,864 (147.5 KB)
1042
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1043
- ),
1044
- kernel_dilation=1,
1045
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1046
- kernel_shape=(3, 3, 128, 128),
1047
- kernel_size=(3, 3),
1048
- mask=None,
1049
- out_features=128,
1050
- padding=(1, 1),
1051
- param_dtype=float32,
1052
- precision=None,
1053
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1054
- strides=(1, 1),
1055
- use_bias=True
1056
- ),
1057
- method='pixel_shuffle',
1058
- pixel_shuffle=PixelShuffle(
1059
- scale=2
1060
- ),
1061
- scale_factor=2
1062
- )]
1063
- ), TimestepEmbedSequential( # Param: 152,640 (610.6 KB), RngState: 2 (12 B), Total: 152,642 (610.6 KB)
1064
- layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
1065
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1066
- conv1=Conv( # Param: 73,792 (295.2 KB)
1067
- bias=Param( # 64 (256 B)
1068
- value=Array(shape=(64,), dtype=dtype('float32'))
1069
- ),
1070
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1071
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1072
- dtype=float32,
1073
- feature_group_count=1,
1074
- in_features=256,
1075
- input_dilation=1,
1076
- kernel=Param( # 73,728 (294.9 KB)
1077
- value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
1078
- ),
1079
- kernel_dilation=1,
1080
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1081
- kernel_shape=(3, 3, 256, 128),
1082
- kernel_size=(3, 3),
1083
- mask=None,
1084
- out_features=128,
1085
- padding=(1, 1),
1086
- param_dtype=float32,
1087
- precision=None,
1088
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1089
- strides=(1, 1),
1090
- use_bias=True
1091
- ),
1092
- conv2=Conv( # Param: 36,928 (147.7 KB)
1093
- bias=Param( # 64 (256 B)
1094
- value=Array(shape=(64,), dtype=dtype('float32'))
1095
- ),
1096
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1097
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1098
- dtype=float32,
1099
- feature_group_count=1,
1100
- in_features=128,
1101
- input_dilation=1,
1102
- kernel=Param( # 36,864 (147.5 KB)
1103
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1104
- ),
1105
- kernel_dilation=1,
1106
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1107
- kernel_shape=(3, 3, 128, 128),
1108
- kernel_size=(3, 3),
1109
- mask=None,
1110
- out_features=128,
1111
- padding=(1, 1),
1112
- param_dtype=float32,
1113
- precision=None,
1114
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1115
- strides=(1, 1),
1116
- use_bias=True
1117
- ),
1118
- dropout=Dropout( # RngState: 2 (12 B)
1119
- broadcast_dims=(),
1120
- deterministic=False,
1121
- rate=0.1,
1122
- rng_collection='dropout',
1123
- rngs=RngStream( # RngState: 2 (12 B)
1124
- count=RngCount( # 1 (4 B)
1125
- value=Array(0, dtype=uint32),
1126
- tag='default'
1127
- ),
1128
- key=RngKey( # 1 (8 B)
1129
- value=Array((), dtype=key<fry>) overlaying:
1130
- [ 48802331 1548237274],
1131
- tag='default'
1132
- ),
1133
- tag='default'
1134
- )
1135
- ),
1136
- embedding_dim=256,
1137
- in_channels=256,
1138
- nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
1139
- bias=Param( # 64 (256 B)
1140
- value=Array(shape=(64,), dtype=dtype('float32'))
1141
- ),
1142
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1143
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1144
- dtype=float32,
1145
- feature_group_count=1,
1146
- in_features=256,
1147
- input_dilation=1,
1148
- kernel=Param( # 8,192 (32.8 KB)
1149
- value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
1150
- ),
1151
- kernel_dilation=1,
1152
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1153
- kernel_shape=(1, 1, 256, 128),
1154
- kernel_size=(1, 1),
1155
- mask=None,
1156
- out_features=128,
1157
- padding=(0, 0),
1158
- param_dtype=float32,
1159
- precision=None,
1160
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1161
- strides=(1, 1),
1162
- use_bias=True
1163
- ),
1164
- norm1=GroupNorm( # Param: 256 (1.0 KB)
1165
- axis_index_groups=None,
1166
- axis_name=None,
1167
- bias=Param( # 128 (512 B)
1168
- value=Array(shape=(128,), dtype=dtype('float32'))
1169
- ),
1170
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1171
- dtype=float32,
1172
- epsilon=1e-06,
1173
- feature_axis=-1,
1174
- group_size=8,
1175
- num_groups=32,
1176
- param_dtype=float32,
1177
- reduction_axes=None,
1178
- scale=Param( # 128 (512 B)
1179
- value=Array(shape=(128,), dtype=dtype('float32'))
1180
- ),
1181
- scale_init=<function ones at 0x7fb32b86e520>,
1182
- use_bias=True,
1183
- use_fast_variance=True,
1184
- use_scale=True
1185
- ),
1186
- norm2=GroupNorm( # Param: 128 (512 B)
1187
- axis_index_groups=None,
1188
- axis_name=None,
1189
- bias=Param( # 64 (256 B)
1190
- value=Array(shape=(64,), dtype=dtype('float32'))
1191
- ),
1192
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1193
- dtype=float32,
1194
- epsilon=1e-06,
1195
- feature_axis=-1,
1196
- group_size=4,
1197
- num_groups=32,
1198
- param_dtype=float32,
1199
- reduction_axes=None,
1200
- scale=Param( # 64 (256 B)
1201
- value=Array(shape=(64,), dtype=dtype('float32'))
1202
- ),
1203
- scale_init=<function ones at 0x7fb32b86e520>,
1204
- use_bias=True,
1205
- use_fast_variance=True,
1206
- use_scale=True
1207
- ),
1208
- out_channels=128,
1209
- time_mlp=Linear( # Param: 16,512 (66.0 KB)
1210
- bias=Param( # 128 (512 B)
1211
- value=Array(shape=(128,), dtype=dtype('float32'))
1212
- ),
1213
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1214
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
1215
- dtype=float32,
1216
- in_features=256,
1217
- kernel=Param( # 16,384 (65.5 KB)
1218
- value=Array(shape=(128, 128), dtype=dtype('float32'))
1219
- ),
1220
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1221
- out_features=256,
1222
- param_dtype=float32,
1223
- precision=None,
1224
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1225
- use_bias=True
1226
- )
1227
- ), AttnBlock( # Param: 16,768 (67.1 KB)
1228
- dtype=float32,
1229
- head_dim=32,
1230
- k=Conv( # Param: 4,160 (16.6 KB)
1231
- bias=Param( # 64 (256 B)
1232
- value=Array(shape=(64,), dtype=dtype('float32'))
1233
- ),
1234
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1235
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1236
- dtype=float32,
1237
- feature_group_count=1,
1238
- in_features=128,
1239
- input_dilation=1,
1240
- kernel=Param( # 4,096 (16.4 KB)
1241
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1242
- ),
1243
- kernel_dilation=1,
1244
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1245
- kernel_shape=(1, 1, 128, 128),
1246
- kernel_size=(1, 1),
1247
- mask=None,
1248
- out_features=128,
1249
- padding='SAME',
1250
- param_dtype=float32,
1251
- precision=None,
1252
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1253
- strides=1,
1254
- use_bias=True
1255
- ),
1256
- norm=GroupNorm( # Param: 128 (512 B)
1257
- axis_index_groups=None,
1258
- axis_name=None,
1259
- bias=Param( # 64 (256 B)
1260
- value=Array(shape=(64,), dtype=dtype('float32'))
1261
- ),
1262
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1263
- dtype=float32,
1264
- epsilon=1e-06,
1265
- feature_axis=-1,
1266
- group_size=4,
1267
- num_groups=32,
1268
- param_dtype=float32,
1269
- reduction_axes=None,
1270
- scale=Param( # 64 (256 B)
1271
- value=Array(shape=(64,), dtype=dtype('float32'))
1272
- ),
1273
- scale_init=<function ones at 0x7fb32b86e520>,
1274
- use_bias=True,
1275
- use_fast_variance=True,
1276
- use_scale=True
1277
- ),
1278
- num_heads=4,
1279
- proj_out=Conv( # Param: 4,160 (16.6 KB)
1280
- bias=Param( # 64 (256 B)
1281
- value=Array(shape=(64,), dtype=dtype('float32'))
1282
- ),
1283
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1284
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1285
- dtype=float32,
1286
- feature_group_count=1,
1287
- in_features=128,
1288
- input_dilation=1,
1289
- kernel=Param( # 4,096 (16.4 KB)
1290
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1291
- ),
1292
- kernel_dilation=1,
1293
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1294
- kernel_shape=(1, 1, 128, 128),
1295
- kernel_size=(1, 1),
1296
- mask=None,
1297
- out_features=128,
1298
- padding='SAME',
1299
- param_dtype=float32,
1300
- precision=None,
1301
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1302
- strides=1,
1303
- use_bias=True
1304
- ),
1305
- q=Conv( # Param: 4,160 (16.6 KB)
1306
- bias=Param( # 64 (256 B)
1307
- value=Array(shape=(64,), dtype=dtype('float32'))
1308
- ),
1309
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1310
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1311
- dtype=float32,
1312
- feature_group_count=1,
1313
- in_features=128,
1314
- input_dilation=1,
1315
- kernel=Param( # 4,096 (16.4 KB)
1316
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1317
- ),
1318
- kernel_dilation=1,
1319
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1320
- kernel_shape=(1, 1, 128, 128),
1321
- kernel_size=(1, 1),
1322
- mask=None,
1323
- out_features=128,
1324
- padding='SAME',
1325
- param_dtype=float32,
1326
- precision=None,
1327
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1328
- strides=1,
1329
- use_bias=True
1330
- ),
1331
- v=Conv( # Param: 4,160 (16.6 KB)
1332
- bias=Param( # 64 (256 B)
1333
- value=Array(shape=(64,), dtype=dtype('float32'))
1334
- ),
1335
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1336
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1337
- dtype=float32,
1338
- feature_group_count=1,
1339
- in_features=128,
1340
- input_dilation=1,
1341
- kernel=Param( # 4,096 (16.4 KB)
1342
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1343
- ),
1344
- kernel_dilation=1,
1345
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1346
- kernel_shape=(1, 1, 128, 128),
1347
- kernel_size=(1, 1),
1348
- mask=None,
1349
- out_features=128,
1350
- padding='SAME',
1351
- param_dtype=float32,
1352
- precision=None,
1353
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1354
- strides=1,
1355
- use_bias=True
1356
- )
1357
- )]
1358
- ), TimestepEmbedSequential( # Param: 316,736 (1.3 MB), RngState: 2 (12 B), Total: 316,738 (1.3 MB)
1359
- layers=[ResnetBlock( # Param: 115,328 (461.3 KB), RngState: 2 (12 B), Total: 115,330 (461.3 KB)
1360
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1361
- conv1=Conv( # Param: 55,360 (221.4 KB)
1362
- bias=Param( # 64 (256 B)
1363
- value=Array(shape=(64,), dtype=dtype('float32'))
1364
- ),
1365
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1366
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1367
- dtype=float32,
1368
- feature_group_count=1,
1369
- in_features=192,
1370
- input_dilation=1,
1371
- kernel=Param( # 55,296 (221.2 KB)
1372
- value=Array(shape=(3, 3, 96, 64), dtype=dtype('float32'))
1373
- ),
1374
- kernel_dilation=1,
1375
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1376
- kernel_shape=(3, 3, 192, 128),
1377
- kernel_size=(3, 3),
1378
- mask=None,
1379
- out_features=128,
1380
- padding=(1, 1),
1381
- param_dtype=float32,
1382
- precision=None,
1383
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1384
- strides=(1, 1),
1385
- use_bias=True
1386
- ),
1387
- conv2=Conv( # Param: 36,928 (147.7 KB)
1388
- bias=Param( # 64 (256 B)
1389
- value=Array(shape=(64,), dtype=dtype('float32'))
1390
- ),
1391
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1392
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1393
- dtype=float32,
1394
- feature_group_count=1,
1395
- in_features=128,
1396
- input_dilation=1,
1397
- kernel=Param( # 36,864 (147.5 KB)
1398
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1399
- ),
1400
- kernel_dilation=1,
1401
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1402
- kernel_shape=(3, 3, 128, 128),
1403
- kernel_size=(3, 3),
1404
- mask=None,
1405
- out_features=128,
1406
- padding=(1, 1),
1407
- param_dtype=float32,
1408
- precision=None,
1409
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1410
- strides=(1, 1),
1411
- use_bias=True
1412
- ),
1413
- dropout=Dropout( # RngState: 2 (12 B)
1414
- broadcast_dims=(),
1415
- deterministic=False,
1416
- rate=0.1,
1417
- rng_collection='dropout',
1418
- rngs=RngStream( # RngState: 2 (12 B)
1419
- count=RngCount( # 1 (4 B)
1420
- value=Array(0, dtype=uint32),
1421
- tag='default'
1422
- ),
1423
- key=RngKey( # 1 (8 B)
1424
- value=Array((), dtype=key<fry>) overlaying:
1425
- [1596966061 1315822572],
1426
- tag='default'
1427
- ),
1428
- tag='default'
1429
- )
1430
- ),
1431
- embedding_dim=256,
1432
- in_channels=192,
1433
- nin_shortcut=Conv( # Param: 6,208 (24.8 KB)
1434
- bias=Param( # 64 (256 B)
1435
- value=Array(shape=(64,), dtype=dtype('float32'))
1436
- ),
1437
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1438
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1439
- dtype=float32,
1440
- feature_group_count=1,
1441
- in_features=192,
1442
- input_dilation=1,
1443
- kernel=Param( # 6,144 (24.6 KB)
1444
- value=Array(shape=(1, 1, 96, 64), dtype=dtype('float32'))
1445
- ),
1446
- kernel_dilation=1,
1447
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1448
- kernel_shape=(1, 1, 192, 128),
1449
- kernel_size=(1, 1),
1450
- mask=None,
1451
- out_features=128,
1452
- padding=(0, 0),
1453
- param_dtype=float32,
1454
- precision=None,
1455
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1456
- strides=(1, 1),
1457
- use_bias=True
1458
- ),
1459
- norm1=GroupNorm( # Param: 192 (768 B)
1460
- axis_index_groups=None,
1461
- axis_name=None,
1462
- bias=Param( # 96 (384 B)
1463
- value=Array(shape=(96,), dtype=dtype('float32'))
1464
- ),
1465
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1466
- dtype=float32,
1467
- epsilon=1e-06,
1468
- feature_axis=-1,
1469
- group_size=6,
1470
- num_groups=32,
1471
- param_dtype=float32,
1472
- reduction_axes=None,
1473
- scale=Param( # 96 (384 B)
1474
- value=Array(shape=(96,), dtype=dtype('float32'))
1475
- ),
1476
- scale_init=<function ones at 0x7fb32b86e520>,
1477
- use_bias=True,
1478
- use_fast_variance=True,
1479
- use_scale=True
1480
- ),
1481
- norm2=GroupNorm( # Param: 128 (512 B)
1482
- axis_index_groups=None,
1483
- axis_name=None,
1484
- bias=Param( # 64 (256 B)
1485
- value=Array(shape=(64,), dtype=dtype('float32'))
1486
- ),
1487
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1488
- dtype=float32,
1489
- epsilon=1e-06,
1490
- feature_axis=-1,
1491
- group_size=4,
1492
- num_groups=32,
1493
- param_dtype=float32,
1494
- reduction_axes=None,
1495
- scale=Param( # 64 (256 B)
1496
- value=Array(shape=(64,), dtype=dtype('float32'))
1497
- ),
1498
- scale_init=<function ones at 0x7fb32b86e520>,
1499
- use_bias=True,
1500
- use_fast_variance=True,
1501
- use_scale=True
1502
- ),
1503
- out_channels=128,
1504
- time_mlp=Linear( # Param: 16,512 (66.0 KB)
1505
- bias=Param( # 128 (512 B)
1506
- value=Array(shape=(128,), dtype=dtype('float32'))
1507
- ),
1508
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1509
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
1510
- dtype=float32,
1511
- in_features=256,
1512
- kernel=Param( # 16,384 (65.5 KB)
1513
- value=Array(shape=(128, 128), dtype=dtype('float32'))
1514
- ),
1515
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1516
- out_features=256,
1517
- param_dtype=float32,
1518
- precision=None,
1519
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1520
- use_bias=True
1521
- )
1522
- ), AttnBlock( # Param: 16,768 (67.1 KB)
1523
- dtype=float32,
1524
- head_dim=32,
1525
- k=Conv( # Param: 4,160 (16.6 KB)
1526
- bias=Param( # 64 (256 B)
1527
- value=Array(shape=(64,), dtype=dtype('float32'))
1528
- ),
1529
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1530
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1531
- dtype=float32,
1532
- feature_group_count=1,
1533
- in_features=128,
1534
- input_dilation=1,
1535
- kernel=Param( # 4,096 (16.4 KB)
1536
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1537
- ),
1538
- kernel_dilation=1,
1539
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1540
- kernel_shape=(1, 1, 128, 128),
1541
- kernel_size=(1, 1),
1542
- mask=None,
1543
- out_features=128,
1544
- padding='SAME',
1545
- param_dtype=float32,
1546
- precision=None,
1547
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1548
- strides=1,
1549
- use_bias=True
1550
- ),
1551
- norm=GroupNorm( # Param: 128 (512 B)
1552
- axis_index_groups=None,
1553
- axis_name=None,
1554
- bias=Param( # 64 (256 B)
1555
- value=Array(shape=(64,), dtype=dtype('float32'))
1556
- ),
1557
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1558
- dtype=float32,
1559
- epsilon=1e-06,
1560
- feature_axis=-1,
1561
- group_size=4,
1562
- num_groups=32,
1563
- param_dtype=float32,
1564
- reduction_axes=None,
1565
- scale=Param( # 64 (256 B)
1566
- value=Array(shape=(64,), dtype=dtype('float32'))
1567
- ),
1568
- scale_init=<function ones at 0x7fb32b86e520>,
1569
- use_bias=True,
1570
- use_fast_variance=True,
1571
- use_scale=True
1572
- ),
1573
- num_heads=4,
1574
- proj_out=Conv( # Param: 4,160 (16.6 KB)
1575
- bias=Param( # 64 (256 B)
1576
- value=Array(shape=(64,), dtype=dtype('float32'))
1577
- ),
1578
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1579
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1580
- dtype=float32,
1581
- feature_group_count=1,
1582
- in_features=128,
1583
- input_dilation=1,
1584
- kernel=Param( # 4,096 (16.4 KB)
1585
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1586
- ),
1587
- kernel_dilation=1,
1588
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1589
- kernel_shape=(1, 1, 128, 128),
1590
- kernel_size=(1, 1),
1591
- mask=None,
1592
- out_features=128,
1593
- padding='SAME',
1594
- param_dtype=float32,
1595
- precision=None,
1596
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1597
- strides=1,
1598
- use_bias=True
1599
- ),
1600
- q=Conv( # Param: 4,160 (16.6 KB)
1601
- bias=Param( # 64 (256 B)
1602
- value=Array(shape=(64,), dtype=dtype('float32'))
1603
- ),
1604
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1605
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1606
- dtype=float32,
1607
- feature_group_count=1,
1608
- in_features=128,
1609
- input_dilation=1,
1610
- kernel=Param( # 4,096 (16.4 KB)
1611
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1612
- ),
1613
- kernel_dilation=1,
1614
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1615
- kernel_shape=(1, 1, 128, 128),
1616
- kernel_size=(1, 1),
1617
- mask=None,
1618
- out_features=128,
1619
- padding='SAME',
1620
- param_dtype=float32,
1621
- precision=None,
1622
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1623
- strides=1,
1624
- use_bias=True
1625
- ),
1626
- v=Conv( # Param: 4,160 (16.6 KB)
1627
- bias=Param( # 64 (256 B)
1628
- value=Array(shape=(64,), dtype=dtype('float32'))
1629
- ),
1630
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1631
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1632
- dtype=float32,
1633
- feature_group_count=1,
1634
- in_features=128,
1635
- input_dilation=1,
1636
- kernel=Param( # 4,096 (16.4 KB)
1637
- value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1638
- ),
1639
- kernel_dilation=1,
1640
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1641
- kernel_shape=(1, 1, 128, 128),
1642
- kernel_size=(1, 1),
1643
- mask=None,
1644
- out_features=128,
1645
- padding='SAME',
1646
- param_dtype=float32,
1647
- precision=None,
1648
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1649
- strides=1,
1650
- use_bias=True
1651
- )
1652
- ), Upsample( # Param: 184,640 (738.6 KB)
1653
- conv_pixel_shuffle=Conv( # Param: 147,712 (590.8 KB)
1654
- bias=Param( # 256 (1.0 KB)
1655
- value=Array(shape=(256,), dtype=dtype('float32'))
1656
- ),
1657
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1658
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1659
- dtype=float32,
1660
- feature_group_count=1,
1661
- in_features=128,
1662
- input_dilation=1,
1663
- kernel=Param( # 147,456 (589.8 KB)
1664
- value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32'))
1665
- ),
1666
- kernel_dilation=1,
1667
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1668
- kernel_shape=(3, 3, 128, 512),
1669
- kernel_size=(3, 3),
1670
- mask=None,
1671
- out_features=512,
1672
- padding=(1, 1),
1673
- param_dtype=float32,
1674
- precision=None,
1675
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1676
- strides=(1, 1),
1677
- use_bias=True
1678
- ),
1679
- conv_resize=Conv( # Param: 36,928 (147.7 KB)
1680
- bias=Param( # 64 (256 B)
1681
- value=Array(shape=(64,), dtype=dtype('float32'))
1682
- ),
1683
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1684
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1685
- dtype=float32,
1686
- feature_group_count=1,
1687
- in_features=128,
1688
- input_dilation=1,
1689
- kernel=Param( # 36,864 (147.5 KB)
1690
- value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1691
- ),
1692
- kernel_dilation=1,
1693
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1694
- kernel_shape=(3, 3, 128, 128),
1695
- kernel_size=(3, 3),
1696
- mask=None,
1697
- out_features=128,
1698
- padding=(1, 1),
1699
- param_dtype=float32,
1700
- precision=None,
1701
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1702
- strides=(1, 1),
1703
- use_bias=True
1704
- ),
1705
- method='pixel_shuffle',
1706
- pixel_shuffle=PixelShuffle(
1707
- scale=2
1708
- ),
1709
- scale_factor=2
1710
- )]
1711
- ), TimestepEmbedSequential( # Param: 48,544 (194.2 KB), RngState: 2 (12 B), Total: 48,546 (194.2 KB)
1712
- layers=[ResnetBlock( # Param: 48,544 (194.2 KB), RngState: 2 (12 B), Total: 48,546 (194.2 KB)
1713
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1714
- conv1=Conv( # Param: 27,680 (110.7 KB)
1715
- bias=Param( # 32 (128 B)
1716
- value=Array(shape=(32,), dtype=dtype('float32'))
1717
- ),
1718
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1719
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1720
- dtype=float32,
1721
- feature_group_count=1,
1722
- in_features=192,
1723
- input_dilation=1,
1724
- kernel=Param( # 27,648 (110.6 KB)
1725
- value=Array(shape=(3, 3, 96, 32), dtype=dtype('float32'))
1726
- ),
1727
- kernel_dilation=1,
1728
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1729
- kernel_shape=(3, 3, 192, 64),
1730
- kernel_size=(3, 3),
1731
- mask=None,
1732
- out_features=64,
1733
- padding=(1, 1),
1734
- param_dtype=float32,
1735
- precision=None,
1736
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1737
- strides=(1, 1),
1738
- use_bias=True
1739
- ),
1740
- conv2=Conv( # Param: 9,248 (37.0 KB)
1741
- bias=Param( # 32 (128 B)
1742
- value=Array(shape=(32,), dtype=dtype('float32'))
1743
- ),
1744
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1745
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1746
- dtype=float32,
1747
- feature_group_count=1,
1748
- in_features=64,
1749
- input_dilation=1,
1750
- kernel=Param( # 9,216 (36.9 KB)
1751
- value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
1752
- ),
1753
- kernel_dilation=1,
1754
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1755
- kernel_shape=(3, 3, 64, 64),
1756
- kernel_size=(3, 3),
1757
- mask=None,
1758
- out_features=64,
1759
- padding=(1, 1),
1760
- param_dtype=float32,
1761
- precision=None,
1762
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1763
- strides=(1, 1),
1764
- use_bias=True
1765
- ),
1766
- dropout=Dropout( # RngState: 2 (12 B)
1767
- broadcast_dims=(),
1768
- deterministic=False,
1769
- rate=0.1,
1770
- rng_collection='dropout',
1771
- rngs=RngStream( # RngState: 2 (12 B)
1772
- count=RngCount( # 1 (4 B)
1773
- value=Array(0, dtype=uint32),
1774
- tag='default'
1775
- ),
1776
- key=RngKey( # 1 (8 B)
1777
- value=Array((), dtype=key<fry>) overlaying:
1778
- [2550820645 2818876438],
1779
- tag='default'
1780
- ),
1781
- tag='default'
1782
- )
1783
- ),
1784
- embedding_dim=256,
1785
- in_channels=192,
1786
- nin_shortcut=Conv( # Param: 3,104 (12.4 KB)
1787
- bias=Param( # 32 (128 B)
1788
- value=Array(shape=(32,), dtype=dtype('float32'))
1789
- ),
1790
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1791
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1792
- dtype=float32,
1793
- feature_group_count=1,
1794
- in_features=192,
1795
- input_dilation=1,
1796
- kernel=Param( # 3,072 (12.3 KB)
1797
- value=Array(shape=(1, 1, 96, 32), dtype=dtype('float32'))
1798
- ),
1799
- kernel_dilation=1,
1800
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1801
- kernel_shape=(1, 1, 192, 64),
1802
- kernel_size=(1, 1),
1803
- mask=None,
1804
- out_features=64,
1805
- padding=(0, 0),
1806
- param_dtype=float32,
1807
- precision=None,
1808
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1809
- strides=(1, 1),
1810
- use_bias=True
1811
- ),
1812
- norm1=GroupNorm( # Param: 192 (768 B)
1813
- axis_index_groups=None,
1814
- axis_name=None,
1815
- bias=Param( # 96 (384 B)
1816
- value=Array(shape=(96,), dtype=dtype('float32'))
1817
- ),
1818
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1819
- dtype=float32,
1820
- epsilon=1e-06,
1821
- feature_axis=-1,
1822
- group_size=6,
1823
- num_groups=32,
1824
- param_dtype=float32,
1825
- reduction_axes=None,
1826
- scale=Param( # 96 (384 B)
1827
- value=Array(shape=(96,), dtype=dtype('float32'))
1828
- ),
1829
- scale_init=<function ones at 0x7fb32b86e520>,
1830
- use_bias=True,
1831
- use_fast_variance=True,
1832
- use_scale=True
1833
- ),
1834
- norm2=GroupNorm( # Param: 64 (256 B)
1835
- axis_index_groups=None,
1836
- axis_name=None,
1837
- bias=Param( # 32 (128 B)
1838
- value=Array(shape=(32,), dtype=dtype('float32'))
1839
- ),
1840
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1841
- dtype=float32,
1842
- epsilon=1e-06,
1843
- feature_axis=-1,
1844
- group_size=2,
1845
- num_groups=32,
1846
- param_dtype=float32,
1847
- reduction_axes=None,
1848
- scale=Param( # 32 (128 B)
1849
- value=Array(shape=(32,), dtype=dtype('float32'))
1850
- ),
1851
- scale_init=<function ones at 0x7fb32b86e520>,
1852
- use_bias=True,
1853
- use_fast_variance=True,
1854
- use_scale=True
1855
- ),
1856
- out_channels=64,
1857
- time_mlp=Linear( # Param: 8,256 (33.0 KB)
1858
- bias=Param( # 64 (256 B)
1859
- value=Array(shape=(64,), dtype=dtype('float32'))
1860
- ),
1861
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1862
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
1863
- dtype=float32,
1864
- in_features=256,
1865
- kernel=Param( # 8,192 (32.8 KB)
1866
- value=Array(shape=(128, 64), dtype=dtype('float32'))
1867
- ),
1868
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1869
- out_features=128,
1870
- param_dtype=float32,
1871
- precision=None,
1872
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1873
- use_bias=True
1874
- )
1875
- )]
1876
- ), TimestepEmbedSequential( # Param: 38,240 (153.0 KB), RngState: 2 (12 B), Total: 38,242 (153.0 KB)
1877
- layers=[ResnetBlock( # Param: 38,240 (153.0 KB), RngState: 2 (12 B), Total: 38,242 (153.0 KB)
1878
- activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1879
- conv1=Conv( # Param: 18,464 (73.9 KB)
1880
- bias=Param( # 32 (128 B)
1881
- value=Array(shape=(32,), dtype=dtype('float32'))
1882
- ),
1883
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1884
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1885
- dtype=float32,
1886
- feature_group_count=1,
1887
- in_features=128,
1888
- input_dilation=1,
1889
- kernel=Param( # 18,432 (73.7 KB)
1890
- value=Array(shape=(3, 3, 64, 32), dtype=dtype('float32'))
1891
- ),
1892
- kernel_dilation=1,
1893
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1894
- kernel_shape=(3, 3, 128, 64),
1895
- kernel_size=(3, 3),
1896
- mask=None,
1897
- out_features=64,
1898
- padding=(1, 1),
1899
- param_dtype=float32,
1900
- precision=None,
1901
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1902
- strides=(1, 1),
1903
- use_bias=True
1904
- ),
1905
- conv2=Conv( # Param: 9,248 (37.0 KB)
1906
- bias=Param( # 32 (128 B)
1907
- value=Array(shape=(32,), dtype=dtype('float32'))
1908
- ),
1909
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1910
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1911
- dtype=float32,
1912
- feature_group_count=1,
1913
- in_features=64,
1914
- input_dilation=1,
1915
- kernel=Param( # 9,216 (36.9 KB)
1916
- value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
1917
- ),
1918
- kernel_dilation=1,
1919
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1920
- kernel_shape=(3, 3, 64, 64),
1921
- kernel_size=(3, 3),
1922
- mask=None,
1923
- out_features=64,
1924
- padding=(1, 1),
1925
- param_dtype=float32,
1926
- precision=None,
1927
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1928
- strides=(1, 1),
1929
- use_bias=True
1930
- ),
1931
- dropout=Dropout( # RngState: 2 (12 B)
1932
- broadcast_dims=(),
1933
- deterministic=False,
1934
- rate=0.1,
1935
- rng_collection='dropout',
1936
- rngs=RngStream( # RngState: 2 (12 B)
1937
- count=RngCount( # 1 (4 B)
1938
- value=Array(0, dtype=uint32),
1939
- tag='default'
1940
- ),
1941
- key=RngKey( # 1 (8 B)
1942
- value=Array((), dtype=key<fry>) overlaying:
1943
- [1975238715 3717004500],
1944
- tag='default'
1945
- ),
1946
- tag='default'
1947
- )
1948
- ),
1949
- embedding_dim=256,
1950
- in_channels=128,
1951
- nin_shortcut=Conv( # Param: 2,080 (8.3 KB)
1952
- bias=Param( # 32 (128 B)
1953
- value=Array(shape=(32,), dtype=dtype('float32'))
1954
- ),
1955
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1956
- conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1957
- dtype=float32,
1958
- feature_group_count=1,
1959
- in_features=128,
1960
- input_dilation=1,
1961
- kernel=Param( # 2,048 (8.2 KB)
1962
- value=Array(shape=(1, 1, 64, 32), dtype=dtype('float32'))
1963
- ),
1964
- kernel_dilation=1,
1965
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1966
- kernel_shape=(1, 1, 128, 64),
1967
- kernel_size=(1, 1),
1968
- mask=None,
1969
- out_features=64,
1970
- padding=(0, 0),
1971
- param_dtype=float32,
1972
- precision=None,
1973
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1974
- strides=(1, 1),
1975
- use_bias=True
1976
- ),
1977
- norm1=GroupNorm( # Param: 128 (512 B)
1978
- axis_index_groups=None,
1979
- axis_name=None,
1980
- bias=Param( # 64 (256 B)
1981
- value=Array(shape=(64,), dtype=dtype('float32'))
1982
- ),
1983
- bias_init=<function zeros at 0x7fb32b98c2c0>,
1984
- dtype=float32,
1985
- epsilon=1e-06,
1986
- feature_axis=-1,
1987
- group_size=4,
1988
- num_groups=32,
1989
- param_dtype=float32,
1990
- reduction_axes=None,
1991
- scale=Param( # 64 (256 B)
1992
- value=Array(shape=(64,), dtype=dtype('float32'))
1993
- ),
1994
- scale_init=<function ones at 0x7fb32b86e520>,
1995
- use_bias=True,
1996
- use_fast_variance=True,
1997
- use_scale=True
1998
- ),
1999
- norm2=GroupNorm( # Param: 64 (256 B)
2000
- axis_index_groups=None,
2001
- axis_name=None,
2002
- bias=Param( # 32 (128 B)
2003
- value=Array(shape=(32,), dtype=dtype('float32'))
2004
- ),
2005
- bias_init=<function zeros at 0x7fb32b98c2c0>,
2006
- dtype=float32,
2007
- epsilon=1e-06,
2008
- feature_axis=-1,
2009
- group_size=2,
2010
- num_groups=32,
2011
- param_dtype=float32,
2012
- reduction_axes=None,
2013
- scale=Param( # 32 (128 B)
2014
- value=Array(shape=(32,), dtype=dtype('float32'))
2015
- ),
2016
- scale_init=<function ones at 0x7fb32b86e520>,
2017
- use_bias=True,
2018
- use_fast_variance=True,
2019
- use_scale=True
2020
- ),
2021
- out_channels=64,
2022
- time_mlp=Linear( # Param: 8,256 (33.0 KB)
2023
- bias=Param( # 64 (256 B)
2024
- value=Array(shape=(64,), dtype=dtype('float32'))
2025
- ),
2026
- bias_init=<function zeros at 0x7fb32b98c2c0>,
2027
- dot_general=<function dot_general at 0x7fb32c21a3e0>,
2028
- dtype=float32,
2029
- in_features=256,
2030
- kernel=Param( # 8,192 (32.8 KB)
2031
- value=Array(shape=(128, 64), dtype=dtype('float32'))
2032
- ),
2033
- kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
2034
- out_features=128,
2035
- param_dtype=float32,
2036
- precision=None,
2037
- promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
2038
- use_bias=True
2039
- )
2040
- )]
2041
- )],
2042
- )
 
1
  """Model configuration for jcopo/mnist
2
 
 
3
  Training step: 45000
4
  Precision: float32
5
  """
6
 
7
+ import jax
8
  import jax.numpy as jnp
9
  from flax import nnx
10
+ from triax.models import CondUNet2D
11
 
12
+ # Model definition
13
  model = CondUNet2D(
14
+ in_channels=IN_CHANNELS,
15
+ ch=variant["base_channels"],
16
+ ch_mult=variant["channel_multipliers"],
17
+ num_res_blocks=variant["num_res_blocks"],
18
+ attention_resolutions=variant["attention_resolutions"],
19
+ num_heads=variant["num_heads"],
20
+ dropout=DROPOUT,
21
+ dropout_rate=DROPOUT_RATE,
22
+ activation=nnx.swish,
23
+ param_dtype=jnp.float32,
24
+ dtype=compute_dtype,
25
+ rngs=nnx.Rngs(MODEL_SEED),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  )