jcopo commited on
Commit
0a02da1
·
verified ·
1 Parent(s): f7a97bc

Upload model at step 45000

Browse files
Files changed (1) hide show
  1. config.py +2030 -11
config.py CHANGED
@@ -6,18 +6,2037 @@ Precision: float32
6
  """
7
 
8
  from triax.models.nn.condUNet import CondUNet2D
 
 
9
 
10
  # Model architecture
11
- # TODO: Fill in the actual initialization parameters from your training config
12
  model = CondUNet2D(
13
- # Add your model parameters here
14
- # Example:
15
- # hidden_dim=256,
16
- # num_layers=4,
17
- # etc.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  )
19
-
20
- # Metadata
21
- STEP = 45000
22
- PRECISION = "float32"
23
- MODEL_TYPE = "CondUNet2D"
 
6
  """
7
 
8
  from triax.models.nn.condUNet import CondUNet2D
9
+ import jax.numpy as jnp
10
+ from flax import nnx
11
 
12
  # Model architecture
 
13
  model = CondUNet2D(
14
+ blocks_down=[TimestepEmbedSequential( # Param: 320 (1.3 KB)
15
+ layers=[Conv( # Param: 320 (1.3 KB)
16
+ bias=Param( # 32 (128 B)
17
+ value=Array(shape=(32,), dtype=dtype('float32'))
18
+ ),
19
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
20
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
21
+ dtype=float32,
22
+ feature_group_count=1,
23
+ in_features=1,
24
+ input_dilation=1,
25
+ kernel=Param( # 288 (1.2 KB)
26
+ value=Array(shape=(3, 3, 1, 32), dtype=dtype('float32'))
27
+ ),
28
+ kernel_dilation=1,
29
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
30
+ kernel_shape=(3, 3, 1, 64),
31
+ kernel_size=(3, 3),
32
+ mask=None,
33
+ out_features=64,
34
+ padding=(1, 1),
35
+ param_dtype=float32,
36
+ precision=None,
37
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
38
+ strides=(1, 1),
39
+ use_bias=True
40
+ )]
41
+ ), TimestepEmbedSequential( # Param: 26,880 (107.5 KB), RngState: 2 (12 B), Total: 26,882 (107.5 KB)
42
+ layers=[ResnetBlock( # Param: 26,880 (107.5 KB), RngState: 2 (12 B), Total: 26,882 (107.5 KB)
43
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
44
+ conv1=Conv( # Param: 9,248 (37.0 KB)
45
+ bias=Param( # 32 (128 B)
46
+ value=Array(shape=(32,), dtype=dtype('float32'))
47
+ ),
48
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
49
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
50
+ dtype=float32,
51
+ feature_group_count=1,
52
+ in_features=64,
53
+ input_dilation=1,
54
+ kernel=Param( # 9,216 (36.9 KB)
55
+ value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
56
+ ),
57
+ kernel_dilation=1,
58
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
59
+ kernel_shape=(3, 3, 64, 64),
60
+ kernel_size=(3, 3),
61
+ mask=None,
62
+ out_features=64,
63
+ padding=(1, 1),
64
+ param_dtype=float32,
65
+ precision=None,
66
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
67
+ strides=(1, 1),
68
+ use_bias=True
69
+ ),
70
+ conv2=Conv( # Param: 9,248 (37.0 KB)
71
+ bias=Param( # 32 (128 B)
72
+ value=Array(shape=(32,), dtype=dtype('float32'))
73
+ ),
74
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
75
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
76
+ dtype=float32,
77
+ feature_group_count=1,
78
+ in_features=64,
79
+ input_dilation=1,
80
+ kernel=Param( # 9,216 (36.9 KB)
81
+ value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
82
+ ),
83
+ kernel_dilation=1,
84
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
85
+ kernel_shape=(3, 3, 64, 64),
86
+ kernel_size=(3, 3),
87
+ mask=None,
88
+ out_features=64,
89
+ padding=(1, 1),
90
+ param_dtype=float32,
91
+ precision=None,
92
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
93
+ strides=(1, 1),
94
+ use_bias=True
95
+ ),
96
+ dropout=Dropout( # RngState: 2 (12 B)
97
+ broadcast_dims=(),
98
+ deterministic=False,
99
+ rate=0.1,
100
+ rng_collection='dropout',
101
+ rngs=RngStream( # RngState: 2 (12 B)
102
+ count=RngCount( # 1 (4 B)
103
+ value=Array(0, dtype=uint32),
104
+ tag='default'
105
+ ),
106
+ key=RngKey( # 1 (8 B)
107
+ value=Array((), dtype=key<fry>) overlaying:
108
+ [2585633080 2083471411],
109
+ tag='default'
110
+ ),
111
+ tag='default'
112
+ )
113
+ ),
114
+ embedding_dim=256,
115
+ in_channels=64,
116
+ norm1=GroupNorm( # Param: 64 (256 B)
117
+ axis_index_groups=None,
118
+ axis_name=None,
119
+ bias=Param( # 32 (128 B)
120
+ value=Array(shape=(32,), dtype=dtype('float32'))
121
+ ),
122
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
123
+ dtype=float32,
124
+ epsilon=1e-06,
125
+ feature_axis=-1,
126
+ group_size=2,
127
+ num_groups=32,
128
+ param_dtype=float32,
129
+ reduction_axes=None,
130
+ scale=Param( # 32 (128 B)
131
+ value=Array(shape=(32,), dtype=dtype('float32'))
132
+ ),
133
+ scale_init=<function ones at 0x7fb32b86e520>,
134
+ use_bias=True,
135
+ use_fast_variance=True,
136
+ use_scale=True
137
+ ),
138
+ norm2=GroupNorm( # Param: 64 (256 B)
139
+ axis_index_groups=None,
140
+ axis_name=None,
141
+ bias=Param( # 32 (128 B)
142
+ value=Array(shape=(32,), dtype=dtype('float32'))
143
+ ),
144
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
145
+ dtype=float32,
146
+ epsilon=1e-06,
147
+ feature_axis=-1,
148
+ group_size=2,
149
+ num_groups=32,
150
+ param_dtype=float32,
151
+ reduction_axes=None,
152
+ scale=Param( # 32 (128 B)
153
+ value=Array(shape=(32,), dtype=dtype('float32'))
154
+ ),
155
+ scale_init=<function ones at 0x7fb32b86e520>,
156
+ use_bias=True,
157
+ use_fast_variance=True,
158
+ use_scale=True
159
+ ),
160
+ out_channels=64,
161
+ time_mlp=Linear( # Param: 8,256 (33.0 KB)
162
+ bias=Param( # 64 (256 B)
163
+ value=Array(shape=(64,), dtype=dtype('float32'))
164
+ ),
165
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
166
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
167
+ dtype=float32,
168
+ in_features=256,
169
+ kernel=Param( # 8,192 (32.8 KB)
170
+ value=Array(shape=(128, 64), dtype=dtype('float32'))
171
+ ),
172
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
173
+ out_features=128,
174
+ param_dtype=float32,
175
+ precision=None,
176
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
177
+ use_bias=True
178
+ )
179
+ )]
180
+ ), TimestepEmbedSequential( # Param: 9,248 (37.0 KB)
181
+ layers=[Downsample( # Param: 9,248 (37.0 KB)
182
+ conv=Conv( # Param: 9,248 (37.0 KB)
183
+ bias=Param( # 32 (128 B)
184
+ value=Array(shape=(32,), dtype=dtype('float32'))
185
+ ),
186
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
187
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
188
+ dtype=float32,
189
+ feature_group_count=1,
190
+ in_features=64,
191
+ input_dilation=1,
192
+ kernel=Param( # 9,216 (36.9 KB)
193
+ value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
194
+ ),
195
+ kernel_dilation=1,
196
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
197
+ kernel_shape=(3, 3, 64, 64),
198
+ kernel_size=(3, 3),
199
+ mask=None,
200
+ out_features=64,
201
+ padding=(0, 0),
202
+ param_dtype=float32,
203
+ precision=None,
204
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
205
+ strides=(2, 2),
206
+ use_bias=True
207
+ ),
208
+ method='conv'
209
+ )]
210
+ ), TimestepEmbedSequential( # Param: 91,008 (364.0 KB), RngState: 2 (12 B), Total: 91,010 (364.0 KB)
211
+ layers=[ResnetBlock( # Param: 74,240 (297.0 KB), RngState: 2 (12 B), Total: 74,242 (297.0 KB)
212
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
213
+ conv1=Conv( # Param: 18,496 (74.0 KB)
214
+ bias=Param( # 64 (256 B)
215
+ value=Array(shape=(64,), dtype=dtype('float32'))
216
+ ),
217
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
218
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
219
+ dtype=float32,
220
+ feature_group_count=1,
221
+ in_features=64,
222
+ input_dilation=1,
223
+ kernel=Param( # 18,432 (73.7 KB)
224
+ value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
225
+ ),
226
+ kernel_dilation=1,
227
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
228
+ kernel_shape=(3, 3, 64, 128),
229
+ kernel_size=(3, 3),
230
+ mask=None,
231
+ out_features=128,
232
+ padding=(1, 1),
233
+ param_dtype=float32,
234
+ precision=None,
235
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
236
+ strides=(1, 1),
237
+ use_bias=True
238
+ ),
239
+ conv2=Conv( # Param: 36,928 (147.7 KB)
240
+ bias=Param( # 64 (256 B)
241
+ value=Array(shape=(64,), dtype=dtype('float32'))
242
+ ),
243
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
244
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
245
+ dtype=float32,
246
+ feature_group_count=1,
247
+ in_features=128,
248
+ input_dilation=1,
249
+ kernel=Param( # 36,864 (147.5 KB)
250
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
251
+ ),
252
+ kernel_dilation=1,
253
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
254
+ kernel_shape=(3, 3, 128, 128),
255
+ kernel_size=(3, 3),
256
+ mask=None,
257
+ out_features=128,
258
+ padding=(1, 1),
259
+ param_dtype=float32,
260
+ precision=None,
261
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
262
+ strides=(1, 1),
263
+ use_bias=True
264
+ ),
265
+ dropout=Dropout( # RngState: 2 (12 B)
266
+ broadcast_dims=(),
267
+ deterministic=False,
268
+ rate=0.1,
269
+ rng_collection='dropout',
270
+ rngs=RngStream( # RngState: 2 (12 B)
271
+ count=RngCount( # 1 (4 B)
272
+ value=Array(0, dtype=uint32),
273
+ tag='default'
274
+ ),
275
+ key=RngKey( # 1 (8 B)
276
+ value=Array((), dtype=key<fry>) overlaying:
277
+ [2656139193 2766658851],
278
+ tag='default'
279
+ ),
280
+ tag='default'
281
+ )
282
+ ),
283
+ embedding_dim=256,
284
+ in_channels=64,
285
+ nin_shortcut=Conv( # Param: 2,112 (8.4 KB)
286
+ bias=Param( # 64 (256 B)
287
+ value=Array(shape=(64,), dtype=dtype('float32'))
288
+ ),
289
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
290
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
291
+ dtype=float32,
292
+ feature_group_count=1,
293
+ in_features=64,
294
+ input_dilation=1,
295
+ kernel=Param( # 2,048 (8.2 KB)
296
+ value=Array(shape=(1, 1, 32, 64), dtype=dtype('float32'))
297
+ ),
298
+ kernel_dilation=1,
299
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
300
+ kernel_shape=(1, 1, 64, 128),
301
+ kernel_size=(1, 1),
302
+ mask=None,
303
+ out_features=128,
304
+ padding=(0, 0),
305
+ param_dtype=float32,
306
+ precision=None,
307
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
308
+ strides=(1, 1),
309
+ use_bias=True
310
+ ),
311
+ norm1=GroupNorm( # Param: 64 (256 B)
312
+ axis_index_groups=None,
313
+ axis_name=None,
314
+ bias=Param( # 32 (128 B)
315
+ value=Array(shape=(32,), dtype=dtype('float32'))
316
+ ),
317
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
318
+ dtype=float32,
319
+ epsilon=1e-06,
320
+ feature_axis=-1,
321
+ group_size=2,
322
+ num_groups=32,
323
+ param_dtype=float32,
324
+ reduction_axes=None,
325
+ scale=Param( # 32 (128 B)
326
+ value=Array(shape=(32,), dtype=dtype('float32'))
327
+ ),
328
+ scale_init=<function ones at 0x7fb32b86e520>,
329
+ use_bias=True,
330
+ use_fast_variance=True,
331
+ use_scale=True
332
+ ),
333
+ norm2=GroupNorm( # Param: 128 (512 B)
334
+ axis_index_groups=None,
335
+ axis_name=None,
336
+ bias=Param( # 64 (256 B)
337
+ value=Array(shape=(64,), dtype=dtype('float32'))
338
+ ),
339
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
340
+ dtype=float32,
341
+ epsilon=1e-06,
342
+ feature_axis=-1,
343
+ group_size=4,
344
+ num_groups=32,
345
+ param_dtype=float32,
346
+ reduction_axes=None,
347
+ scale=Param( # 64 (256 B)
348
+ value=Array(shape=(64,), dtype=dtype('float32'))
349
+ ),
350
+ scale_init=<function ones at 0x7fb32b86e520>,
351
+ use_bias=True,
352
+ use_fast_variance=True,
353
+ use_scale=True
354
+ ),
355
+ out_channels=128,
356
+ time_mlp=Linear( # Param: 16,512 (66.0 KB)
357
+ bias=Param( # 128 (512 B)
358
+ value=Array(shape=(128,), dtype=dtype('float32'))
359
+ ),
360
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
361
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
362
+ dtype=float32,
363
+ in_features=256,
364
+ kernel=Param( # 16,384 (65.5 KB)
365
+ value=Array(shape=(128, 128), dtype=dtype('float32'))
366
+ ),
367
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
368
+ out_features=256,
369
+ param_dtype=float32,
370
+ precision=None,
371
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
372
+ use_bias=True
373
+ )
374
+ ), AttnBlock( # Param: 16,768 (67.1 KB)
375
+ dtype=float32,
376
+ head_dim=32,
377
+ k=Conv( # Param: 4,160 (16.6 KB)
378
+ bias=Param( # 64 (256 B)
379
+ value=Array(shape=(64,), dtype=dtype('float32'))
380
+ ),
381
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
382
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
383
+ dtype=float32,
384
+ feature_group_count=1,
385
+ in_features=128,
386
+ input_dilation=1,
387
+ kernel=Param( # 4,096 (16.4 KB)
388
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
389
+ ),
390
+ kernel_dilation=1,
391
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
392
+ kernel_shape=(1, 1, 128, 128),
393
+ kernel_size=(1, 1),
394
+ mask=None,
395
+ out_features=128,
396
+ padding='SAME',
397
+ param_dtype=float32,
398
+ precision=None,
399
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
400
+ strides=1,
401
+ use_bias=True
402
+ ),
403
+ norm=GroupNorm( # Param: 128 (512 B)
404
+ axis_index_groups=None,
405
+ axis_name=None,
406
+ bias=Param( # 64 (256 B)
407
+ value=Array(shape=(64,), dtype=dtype('float32'))
408
+ ),
409
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
410
+ dtype=float32,
411
+ epsilon=1e-06,
412
+ feature_axis=-1,
413
+ group_size=4,
414
+ num_groups=32,
415
+ param_dtype=float32,
416
+ reduction_axes=None,
417
+ scale=Param( # 64 (256 B)
418
+ value=Array(shape=(64,), dtype=dtype('float32'))
419
+ ),
420
+ scale_init=<function ones at 0x7fb32b86e520>,
421
+ use_bias=True,
422
+ use_fast_variance=True,
423
+ use_scale=True
424
+ ),
425
+ num_heads=4,
426
+ proj_out=Conv( # Param: 4,160 (16.6 KB)
427
+ bias=Param( # 64 (256 B)
428
+ value=Array(shape=(64,), dtype=dtype('float32'))
429
+ ),
430
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
431
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
432
+ dtype=float32,
433
+ feature_group_count=1,
434
+ in_features=128,
435
+ input_dilation=1,
436
+ kernel=Param( # 4,096 (16.4 KB)
437
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
438
+ ),
439
+ kernel_dilation=1,
440
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
441
+ kernel_shape=(1, 1, 128, 128),
442
+ kernel_size=(1, 1),
443
+ mask=None,
444
+ out_features=128,
445
+ padding='SAME',
446
+ param_dtype=float32,
447
+ precision=None,
448
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
449
+ strides=1,
450
+ use_bias=True
451
+ ),
452
+ q=Conv( # Param: 4,160 (16.6 KB)
453
+ bias=Param( # 64 (256 B)
454
+ value=Array(shape=(64,), dtype=dtype('float32'))
455
+ ),
456
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
457
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
458
+ dtype=float32,
459
+ feature_group_count=1,
460
+ in_features=128,
461
+ input_dilation=1,
462
+ kernel=Param( # 4,096 (16.4 KB)
463
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
464
+ ),
465
+ kernel_dilation=1,
466
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
467
+ kernel_shape=(1, 1, 128, 128),
468
+ kernel_size=(1, 1),
469
+ mask=None,
470
+ out_features=128,
471
+ padding='SAME',
472
+ param_dtype=float32,
473
+ precision=None,
474
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
475
+ strides=1,
476
+ use_bias=True
477
+ ),
478
+ v=Conv( # Param: 4,160 (16.6 KB)
479
+ bias=Param( # 64 (256 B)
480
+ value=Array(shape=(64,), dtype=dtype('float32'))
481
+ ),
482
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
483
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
484
+ dtype=float32,
485
+ feature_group_count=1,
486
+ in_features=128,
487
+ input_dilation=1,
488
+ kernel=Param( # 4,096 (16.4 KB)
489
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
490
+ ),
491
+ kernel_dilation=1,
492
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
493
+ kernel_shape=(1, 1, 128, 128),
494
+ kernel_size=(1, 1),
495
+ mask=None,
496
+ out_features=128,
497
+ padding='SAME',
498
+ param_dtype=float32,
499
+ precision=None,
500
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
501
+ strides=1,
502
+ use_bias=True
503
+ )
504
+ )]
505
+ ), TimestepEmbedSequential( # Param: 36,928 (147.7 KB)
506
+ layers=[Downsample( # Param: 36,928 (147.7 KB)
507
+ conv=Conv( # Param: 36,928 (147.7 KB)
508
+ bias=Param( # 64 (256 B)
509
+ value=Array(shape=(64,), dtype=dtype('float32'))
510
+ ),
511
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
512
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
513
+ dtype=float32,
514
+ feature_group_count=1,
515
+ in_features=128,
516
+ input_dilation=1,
517
+ kernel=Param( # 36,864 (147.5 KB)
518
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
519
+ ),
520
+ kernel_dilation=1,
521
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
522
+ kernel_shape=(3, 3, 128, 128),
523
+ kernel_size=(3, 3),
524
+ mask=None,
525
+ out_features=128,
526
+ padding=(0, 0),
527
+ param_dtype=float32,
528
+ precision=None,
529
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
530
+ strides=(2, 2),
531
+ use_bias=True
532
+ ),
533
+ method='conv'
534
+ )]
535
+ ), TimestepEmbedSequential( # Param: 90,624 (362.5 KB), RngState: 2 (12 B), Total: 90,626 (362.5 KB)
536
+ layers=[ResnetBlock( # Param: 90,624 (362.5 KB), RngState: 2 (12 B), Total: 90,626 (362.5 KB)
537
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
538
+ conv1=Conv( # Param: 36,928 (147.7 KB)
539
+ bias=Param( # 64 (256 B)
540
+ value=Array(shape=(64,), dtype=dtype('float32'))
541
+ ),
542
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
543
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
544
+ dtype=float32,
545
+ feature_group_count=1,
546
+ in_features=128,
547
+ input_dilation=1,
548
+ kernel=Param( # 36,864 (147.5 KB)
549
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
550
+ ),
551
+ kernel_dilation=1,
552
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
553
+ kernel_shape=(3, 3, 128, 128),
554
+ kernel_size=(3, 3),
555
+ mask=None,
556
+ out_features=128,
557
+ padding=(1, 1),
558
+ param_dtype=float32,
559
+ precision=None,
560
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
561
+ strides=(1, 1),
562
+ use_bias=True
563
+ ),
564
+ conv2=Conv( # Param: 36,928 (147.7 KB)
565
+ bias=Param( # 64 (256 B)
566
+ value=Array(shape=(64,), dtype=dtype('float32'))
567
+ ),
568
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
569
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
570
+ dtype=float32,
571
+ feature_group_count=1,
572
+ in_features=128,
573
+ input_dilation=1,
574
+ kernel=Param( # 36,864 (147.5 KB)
575
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
576
+ ),
577
+ kernel_dilation=1,
578
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
579
+ kernel_shape=(3, 3, 128, 128),
580
+ kernel_size=(3, 3),
581
+ mask=None,
582
+ out_features=128,
583
+ padding=(1, 1),
584
+ param_dtype=float32,
585
+ precision=None,
586
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
587
+ strides=(1, 1),
588
+ use_bias=True
589
+ ),
590
+ dropout=Dropout( # RngState: 2 (12 B)
591
+ broadcast_dims=(),
592
+ deterministic=False,
593
+ rate=0.1,
594
+ rng_collection='dropout',
595
+ rngs=RngStream( # RngState: 2 (12 B)
596
+ count=RngCount( # 1 (4 B)
597
+ value=Array(0, dtype=uint32),
598
+ tag='default'
599
+ ),
600
+ key=RngKey( # 1 (8 B)
601
+ value=Array((), dtype=key<fry>) overlaying:
602
+ [ 692128146 1043829861],
603
+ tag='default'
604
+ ),
605
+ tag='default'
606
+ )
607
+ ),
608
+ embedding_dim=256,
609
+ in_channels=128,
610
+ norm1=GroupNorm( # Param: 128 (512 B)
611
+ axis_index_groups=None,
612
+ axis_name=None,
613
+ bias=Param( # 64 (256 B)
614
+ value=Array(shape=(64,), dtype=dtype('float32'))
615
+ ),
616
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
617
+ dtype=float32,
618
+ epsilon=1e-06,
619
+ feature_axis=-1,
620
+ group_size=4,
621
+ num_groups=32,
622
+ param_dtype=float32,
623
+ reduction_axes=None,
624
+ scale=Param( # 64 (256 B)
625
+ value=Array(shape=(64,), dtype=dtype('float32'))
626
+ ),
627
+ scale_init=<function ones at 0x7fb32b86e520>,
628
+ use_bias=True,
629
+ use_fast_variance=True,
630
+ use_scale=True
631
+ ),
632
+ norm2=GroupNorm( # Param: 128 (512 B)
633
+ axis_index_groups=None,
634
+ axis_name=None,
635
+ bias=Param( # 64 (256 B)
636
+ value=Array(shape=(64,), dtype=dtype('float32'))
637
+ ),
638
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
639
+ dtype=float32,
640
+ epsilon=1e-06,
641
+ feature_axis=-1,
642
+ group_size=4,
643
+ num_groups=32,
644
+ param_dtype=float32,
645
+ reduction_axes=None,
646
+ scale=Param( # 64 (256 B)
647
+ value=Array(shape=(64,), dtype=dtype('float32'))
648
+ ),
649
+ scale_init=<function ones at 0x7fb32b86e520>,
650
+ use_bias=True,
651
+ use_fast_variance=True,
652
+ use_scale=True
653
+ ),
654
+ out_channels=128,
655
+ time_mlp=Linear( # Param: 16,512 (66.0 KB)
656
+ bias=Param( # 128 (512 B)
657
+ value=Array(shape=(128,), dtype=dtype('float32'))
658
+ ),
659
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
660
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
661
+ dtype=float32,
662
+ in_features=256,
663
+ kernel=Param( # 16,384 (65.5 KB)
664
+ value=Array(shape=(128, 128), dtype=dtype('float32'))
665
+ ),
666
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
667
+ out_features=256,
668
+ param_dtype=float32,
669
+ precision=None,
670
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
671
+ use_bias=True
672
+ )
673
+ )]
674
+ )],
675
+ blocks_up=[TimestepEmbedSequential( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
676
+ layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
677
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
678
+ conv1=Conv( # Param: 73,792 (295.2 KB)
679
+ bias=Param( # 64 (256 B)
680
+ value=Array(shape=(64,), dtype=dtype('float32'))
681
+ ),
682
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
683
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
684
+ dtype=float32,
685
+ feature_group_count=1,
686
+ in_features=256,
687
+ input_dilation=1,
688
+ kernel=Param( # 73,728 (294.9 KB)
689
+ value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
690
+ ),
691
+ kernel_dilation=1,
692
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
693
+ kernel_shape=(3, 3, 256, 128),
694
+ kernel_size=(3, 3),
695
+ mask=None,
696
+ out_features=128,
697
+ padding=(1, 1),
698
+ param_dtype=float32,
699
+ precision=None,
700
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
701
+ strides=(1, 1),
702
+ use_bias=True
703
+ ),
704
+ conv2=Conv( # Param: 36,928 (147.7 KB)
705
+ bias=Param( # 64 (256 B)
706
+ value=Array(shape=(64,), dtype=dtype('float32'))
707
+ ),
708
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
709
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
710
+ dtype=float32,
711
+ feature_group_count=1,
712
+ in_features=128,
713
+ input_dilation=1,
714
+ kernel=Param( # 36,864 (147.5 KB)
715
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
716
+ ),
717
+ kernel_dilation=1,
718
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
719
+ kernel_shape=(3, 3, 128, 128),
720
+ kernel_size=(3, 3),
721
+ mask=None,
722
+ out_features=128,
723
+ padding=(1, 1),
724
+ param_dtype=float32,
725
+ precision=None,
726
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
727
+ strides=(1, 1),
728
+ use_bias=True
729
+ ),
730
+ dropout=Dropout( # RngState: 2 (12 B)
731
+ broadcast_dims=(),
732
+ deterministic=False,
733
+ rate=0.1,
734
+ rng_collection='dropout',
735
+ rngs=RngStream( # RngState: 2 (12 B)
736
+ count=RngCount( # 1 (4 B)
737
+ value=Array(0, dtype=uint32),
738
+ tag='default'
739
+ ),
740
+ key=RngKey( # 1 (8 B)
741
+ value=Array((), dtype=key<fry>) overlaying:
742
+ [2853902436 2217684095],
743
+ tag='default'
744
+ ),
745
+ tag='default'
746
+ )
747
+ ),
748
+ embedding_dim=256,
749
+ in_channels=256,
750
+ nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
751
+ bias=Param( # 64 (256 B)
752
+ value=Array(shape=(64,), dtype=dtype('float32'))
753
+ ),
754
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
755
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
756
+ dtype=float32,
757
+ feature_group_count=1,
758
+ in_features=256,
759
+ input_dilation=1,
760
+ kernel=Param( # 8,192 (32.8 KB)
761
+ value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
762
+ ),
763
+ kernel_dilation=1,
764
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
765
+ kernel_shape=(1, 1, 256, 128),
766
+ kernel_size=(1, 1),
767
+ mask=None,
768
+ out_features=128,
769
+ padding=(0, 0),
770
+ param_dtype=float32,
771
+ precision=None,
772
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
773
+ strides=(1, 1),
774
+ use_bias=True
775
+ ),
776
+ norm1=GroupNorm( # Param: 256 (1.0 KB)
777
+ axis_index_groups=None,
778
+ axis_name=None,
779
+ bias=Param( # 128 (512 B)
780
+ value=Array(shape=(128,), dtype=dtype('float32'))
781
+ ),
782
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
783
+ dtype=float32,
784
+ epsilon=1e-06,
785
+ feature_axis=-1,
786
+ group_size=8,
787
+ num_groups=32,
788
+ param_dtype=float32,
789
+ reduction_axes=None,
790
+ scale=Param( # 128 (512 B)
791
+ value=Array(shape=(128,), dtype=dtype('float32'))
792
+ ),
793
+ scale_init=<function ones at 0x7fb32b86e520>,
794
+ use_bias=True,
795
+ use_fast_variance=True,
796
+ use_scale=True
797
+ ),
798
+ norm2=GroupNorm( # Param: 128 (512 B)
799
+ axis_index_groups=None,
800
+ axis_name=None,
801
+ bias=Param( # 64 (256 B)
802
+ value=Array(shape=(64,), dtype=dtype('float32'))
803
+ ),
804
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
805
+ dtype=float32,
806
+ epsilon=1e-06,
807
+ feature_axis=-1,
808
+ group_size=4,
809
+ num_groups=32,
810
+ param_dtype=float32,
811
+ reduction_axes=None,
812
+ scale=Param( # 64 (256 B)
813
+ value=Array(shape=(64,), dtype=dtype('float32'))
814
+ ),
815
+ scale_init=<function ones at 0x7fb32b86e520>,
816
+ use_bias=True,
817
+ use_fast_variance=True,
818
+ use_scale=True
819
+ ),
820
+ out_channels=128,
821
+ time_mlp=Linear( # Param: 16,512 (66.0 KB)
822
+ bias=Param( # 128 (512 B)
823
+ value=Array(shape=(128,), dtype=dtype('float32'))
824
+ ),
825
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
826
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
827
+ dtype=float32,
828
+ in_features=256,
829
+ kernel=Param( # 16,384 (65.5 KB)
830
+ value=Array(shape=(128, 128), dtype=dtype('float32'))
831
+ ),
832
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
833
+ out_features=256,
834
+ param_dtype=float32,
835
+ precision=None,
836
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
837
+ use_bias=True
838
+ )
839
+ )]
840
+ ), TimestepEmbedSequential( # Param: 320,512 (1.3 MB), RngState: 2 (12 B), Total: 320,514 (1.3 MB)
841
+ layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
842
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
843
+ conv1=Conv( # Param: 73,792 (295.2 KB)
844
+ bias=Param( # 64 (256 B)
845
+ value=Array(shape=(64,), dtype=dtype('float32'))
846
+ ),
847
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
848
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
849
+ dtype=float32,
850
+ feature_group_count=1,
851
+ in_features=256,
852
+ input_dilation=1,
853
+ kernel=Param( # 73,728 (294.9 KB)
854
+ value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
855
+ ),
856
+ kernel_dilation=1,
857
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
858
+ kernel_shape=(3, 3, 256, 128),
859
+ kernel_size=(3, 3),
860
+ mask=None,
861
+ out_features=128,
862
+ padding=(1, 1),
863
+ param_dtype=float32,
864
+ precision=None,
865
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
866
+ strides=(1, 1),
867
+ use_bias=True
868
+ ),
869
+ conv2=Conv( # Param: 36,928 (147.7 KB)
870
+ bias=Param( # 64 (256 B)
871
+ value=Array(shape=(64,), dtype=dtype('float32'))
872
+ ),
873
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
874
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
875
+ dtype=float32,
876
+ feature_group_count=1,
877
+ in_features=128,
878
+ input_dilation=1,
879
+ kernel=Param( # 36,864 (147.5 KB)
880
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
881
+ ),
882
+ kernel_dilation=1,
883
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
884
+ kernel_shape=(3, 3, 128, 128),
885
+ kernel_size=(3, 3),
886
+ mask=None,
887
+ out_features=128,
888
+ padding=(1, 1),
889
+ param_dtype=float32,
890
+ precision=None,
891
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
892
+ strides=(1, 1),
893
+ use_bias=True
894
+ ),
895
+ dropout=Dropout( # RngState: 2 (12 B)
896
+ broadcast_dims=(),
897
+ deterministic=False,
898
+ rate=0.1,
899
+ rng_collection='dropout',
900
+ rngs=RngStream( # RngState: 2 (12 B)
901
+ count=RngCount( # 1 (4 B)
902
+ value=Array(0, dtype=uint32),
903
+ tag='default'
904
+ ),
905
+ key=RngKey( # 1 (8 B)
906
+ value=Array((), dtype=key<fry>) overlaying:
907
+ [2785098898 841100811],
908
+ tag='default'
909
+ ),
910
+ tag='default'
911
+ )
912
+ ),
913
+ embedding_dim=256,
914
+ in_channels=256,
915
+ nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
916
+ bias=Param( # 64 (256 B)
917
+ value=Array(shape=(64,), dtype=dtype('float32'))
918
+ ),
919
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
920
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
921
+ dtype=float32,
922
+ feature_group_count=1,
923
+ in_features=256,
924
+ input_dilation=1,
925
+ kernel=Param( # 8,192 (32.8 KB)
926
+ value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
927
+ ),
928
+ kernel_dilation=1,
929
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
930
+ kernel_shape=(1, 1, 256, 128),
931
+ kernel_size=(1, 1),
932
+ mask=None,
933
+ out_features=128,
934
+ padding=(0, 0),
935
+ param_dtype=float32,
936
+ precision=None,
937
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
938
+ strides=(1, 1),
939
+ use_bias=True
940
+ ),
941
+ norm1=GroupNorm( # Param: 256 (1.0 KB)
942
+ axis_index_groups=None,
943
+ axis_name=None,
944
+ bias=Param( # 128 (512 B)
945
+ value=Array(shape=(128,), dtype=dtype('float32'))
946
+ ),
947
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
948
+ dtype=float32,
949
+ epsilon=1e-06,
950
+ feature_axis=-1,
951
+ group_size=8,
952
+ num_groups=32,
953
+ param_dtype=float32,
954
+ reduction_axes=None,
955
+ scale=Param( # 128 (512 B)
956
+ value=Array(shape=(128,), dtype=dtype('float32'))
957
+ ),
958
+ scale_init=<function ones at 0x7fb32b86e520>,
959
+ use_bias=True,
960
+ use_fast_variance=True,
961
+ use_scale=True
962
+ ),
963
+ norm2=GroupNorm( # Param: 128 (512 B)
964
+ axis_index_groups=None,
965
+ axis_name=None,
966
+ bias=Param( # 64 (256 B)
967
+ value=Array(shape=(64,), dtype=dtype('float32'))
968
+ ),
969
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
970
+ dtype=float32,
971
+ epsilon=1e-06,
972
+ feature_axis=-1,
973
+ group_size=4,
974
+ num_groups=32,
975
+ param_dtype=float32,
976
+ reduction_axes=None,
977
+ scale=Param( # 64 (256 B)
978
+ value=Array(shape=(64,), dtype=dtype('float32'))
979
+ ),
980
+ scale_init=<function ones at 0x7fb32b86e520>,
981
+ use_bias=True,
982
+ use_fast_variance=True,
983
+ use_scale=True
984
+ ),
985
+ out_channels=128,
986
+ time_mlp=Linear( # Param: 16,512 (66.0 KB)
987
+ bias=Param( # 128 (512 B)
988
+ value=Array(shape=(128,), dtype=dtype('float32'))
989
+ ),
990
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
991
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
992
+ dtype=float32,
993
+ in_features=256,
994
+ kernel=Param( # 16,384 (65.5 KB)
995
+ value=Array(shape=(128, 128), dtype=dtype('float32'))
996
+ ),
997
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
998
+ out_features=256,
999
+ param_dtype=float32,
1000
+ precision=None,
1001
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1002
+ use_bias=True
1003
+ )
1004
+ ), Upsample( # Param: 184,640 (738.6 KB)
1005
+ conv_pixel_shuffle=Conv( # Param: 147,712 (590.8 KB)
1006
+ bias=Param( # 256 (1.0 KB)
1007
+ value=Array(shape=(256,), dtype=dtype('float32'))
1008
+ ),
1009
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1010
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1011
+ dtype=float32,
1012
+ feature_group_count=1,
1013
+ in_features=128,
1014
+ input_dilation=1,
1015
+ kernel=Param( # 147,456 (589.8 KB)
1016
+ value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32'))
1017
+ ),
1018
+ kernel_dilation=1,
1019
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1020
+ kernel_shape=(3, 3, 128, 512),
1021
+ kernel_size=(3, 3),
1022
+ mask=None,
1023
+ out_features=512,
1024
+ padding=(1, 1),
1025
+ param_dtype=float32,
1026
+ precision=None,
1027
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1028
+ strides=(1, 1),
1029
+ use_bias=True
1030
+ ),
1031
+ conv_resize=Conv( # Param: 36,928 (147.7 KB)
1032
+ bias=Param( # 64 (256 B)
1033
+ value=Array(shape=(64,), dtype=dtype('float32'))
1034
+ ),
1035
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1036
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1037
+ dtype=float32,
1038
+ feature_group_count=1,
1039
+ in_features=128,
1040
+ input_dilation=1,
1041
+ kernel=Param( # 36,864 (147.5 KB)
1042
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1043
+ ),
1044
+ kernel_dilation=1,
1045
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1046
+ kernel_shape=(3, 3, 128, 128),
1047
+ kernel_size=(3, 3),
1048
+ mask=None,
1049
+ out_features=128,
1050
+ padding=(1, 1),
1051
+ param_dtype=float32,
1052
+ precision=None,
1053
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1054
+ strides=(1, 1),
1055
+ use_bias=True
1056
+ ),
1057
+ method='pixel_shuffle',
1058
+ pixel_shuffle=PixelShuffle(
1059
+ scale=2
1060
+ ),
1061
+ scale_factor=2
1062
+ )]
1063
+ ), TimestepEmbedSequential( # Param: 152,640 (610.6 KB), RngState: 2 (12 B), Total: 152,642 (610.6 KB)
1064
+ layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
1065
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1066
+ conv1=Conv( # Param: 73,792 (295.2 KB)
1067
+ bias=Param( # 64 (256 B)
1068
+ value=Array(shape=(64,), dtype=dtype('float32'))
1069
+ ),
1070
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1071
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1072
+ dtype=float32,
1073
+ feature_group_count=1,
1074
+ in_features=256,
1075
+ input_dilation=1,
1076
+ kernel=Param( # 73,728 (294.9 KB)
1077
+ value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
1078
+ ),
1079
+ kernel_dilation=1,
1080
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1081
+ kernel_shape=(3, 3, 256, 128),
1082
+ kernel_size=(3, 3),
1083
+ mask=None,
1084
+ out_features=128,
1085
+ padding=(1, 1),
1086
+ param_dtype=float32,
1087
+ precision=None,
1088
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1089
+ strides=(1, 1),
1090
+ use_bias=True
1091
+ ),
1092
+ conv2=Conv( # Param: 36,928 (147.7 KB)
1093
+ bias=Param( # 64 (256 B)
1094
+ value=Array(shape=(64,), dtype=dtype('float32'))
1095
+ ),
1096
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1097
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1098
+ dtype=float32,
1099
+ feature_group_count=1,
1100
+ in_features=128,
1101
+ input_dilation=1,
1102
+ kernel=Param( # 36,864 (147.5 KB)
1103
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1104
+ ),
1105
+ kernel_dilation=1,
1106
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1107
+ kernel_shape=(3, 3, 128, 128),
1108
+ kernel_size=(3, 3),
1109
+ mask=None,
1110
+ out_features=128,
1111
+ padding=(1, 1),
1112
+ param_dtype=float32,
1113
+ precision=None,
1114
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1115
+ strides=(1, 1),
1116
+ use_bias=True
1117
+ ),
1118
+ dropout=Dropout( # RngState: 2 (12 B)
1119
+ broadcast_dims=(),
1120
+ deterministic=False,
1121
+ rate=0.1,
1122
+ rng_collection='dropout',
1123
+ rngs=RngStream( # RngState: 2 (12 B)
1124
+ count=RngCount( # 1 (4 B)
1125
+ value=Array(0, dtype=uint32),
1126
+ tag='default'
1127
+ ),
1128
+ key=RngKey( # 1 (8 B)
1129
+ value=Array((), dtype=key<fry>) overlaying:
1130
+ [ 48802331 1548237274],
1131
+ tag='default'
1132
+ ),
1133
+ tag='default'
1134
+ )
1135
+ ),
1136
+ embedding_dim=256,
1137
+ in_channels=256,
1138
+ nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
1139
+ bias=Param( # 64 (256 B)
1140
+ value=Array(shape=(64,), dtype=dtype('float32'))
1141
+ ),
1142
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1143
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1144
+ dtype=float32,
1145
+ feature_group_count=1,
1146
+ in_features=256,
1147
+ input_dilation=1,
1148
+ kernel=Param( # 8,192 (32.8 KB)
1149
+ value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
1150
+ ),
1151
+ kernel_dilation=1,
1152
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1153
+ kernel_shape=(1, 1, 256, 128),
1154
+ kernel_size=(1, 1),
1155
+ mask=None,
1156
+ out_features=128,
1157
+ padding=(0, 0),
1158
+ param_dtype=float32,
1159
+ precision=None,
1160
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1161
+ strides=(1, 1),
1162
+ use_bias=True
1163
+ ),
1164
+ norm1=GroupNorm( # Param: 256 (1.0 KB)
1165
+ axis_index_groups=None,
1166
+ axis_name=None,
1167
+ bias=Param( # 128 (512 B)
1168
+ value=Array(shape=(128,), dtype=dtype('float32'))
1169
+ ),
1170
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1171
+ dtype=float32,
1172
+ epsilon=1e-06,
1173
+ feature_axis=-1,
1174
+ group_size=8,
1175
+ num_groups=32,
1176
+ param_dtype=float32,
1177
+ reduction_axes=None,
1178
+ scale=Param( # 128 (512 B)
1179
+ value=Array(shape=(128,), dtype=dtype('float32'))
1180
+ ),
1181
+ scale_init=<function ones at 0x7fb32b86e520>,
1182
+ use_bias=True,
1183
+ use_fast_variance=True,
1184
+ use_scale=True
1185
+ ),
1186
+ norm2=GroupNorm( # Param: 128 (512 B)
1187
+ axis_index_groups=None,
1188
+ axis_name=None,
1189
+ bias=Param( # 64 (256 B)
1190
+ value=Array(shape=(64,), dtype=dtype('float32'))
1191
+ ),
1192
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1193
+ dtype=float32,
1194
+ epsilon=1e-06,
1195
+ feature_axis=-1,
1196
+ group_size=4,
1197
+ num_groups=32,
1198
+ param_dtype=float32,
1199
+ reduction_axes=None,
1200
+ scale=Param( # 64 (256 B)
1201
+ value=Array(shape=(64,), dtype=dtype('float32'))
1202
+ ),
1203
+ scale_init=<function ones at 0x7fb32b86e520>,
1204
+ use_bias=True,
1205
+ use_fast_variance=True,
1206
+ use_scale=True
1207
+ ),
1208
+ out_channels=128,
1209
+ time_mlp=Linear( # Param: 16,512 (66.0 KB)
1210
+ bias=Param( # 128 (512 B)
1211
+ value=Array(shape=(128,), dtype=dtype('float32'))
1212
+ ),
1213
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1214
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
1215
+ dtype=float32,
1216
+ in_features=256,
1217
+ kernel=Param( # 16,384 (65.5 KB)
1218
+ value=Array(shape=(128, 128), dtype=dtype('float32'))
1219
+ ),
1220
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1221
+ out_features=256,
1222
+ param_dtype=float32,
1223
+ precision=None,
1224
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1225
+ use_bias=True
1226
+ )
1227
+ ), AttnBlock( # Param: 16,768 (67.1 KB)
1228
+ dtype=float32,
1229
+ head_dim=32,
1230
+ k=Conv( # Param: 4,160 (16.6 KB)
1231
+ bias=Param( # 64 (256 B)
1232
+ value=Array(shape=(64,), dtype=dtype('float32'))
1233
+ ),
1234
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1235
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1236
+ dtype=float32,
1237
+ feature_group_count=1,
1238
+ in_features=128,
1239
+ input_dilation=1,
1240
+ kernel=Param( # 4,096 (16.4 KB)
1241
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1242
+ ),
1243
+ kernel_dilation=1,
1244
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1245
+ kernel_shape=(1, 1, 128, 128),
1246
+ kernel_size=(1, 1),
1247
+ mask=None,
1248
+ out_features=128,
1249
+ padding='SAME',
1250
+ param_dtype=float32,
1251
+ precision=None,
1252
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1253
+ strides=1,
1254
+ use_bias=True
1255
+ ),
1256
+ norm=GroupNorm( # Param: 128 (512 B)
1257
+ axis_index_groups=None,
1258
+ axis_name=None,
1259
+ bias=Param( # 64 (256 B)
1260
+ value=Array(shape=(64,), dtype=dtype('float32'))
1261
+ ),
1262
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1263
+ dtype=float32,
1264
+ epsilon=1e-06,
1265
+ feature_axis=-1,
1266
+ group_size=4,
1267
+ num_groups=32,
1268
+ param_dtype=float32,
1269
+ reduction_axes=None,
1270
+ scale=Param( # 64 (256 B)
1271
+ value=Array(shape=(64,), dtype=dtype('float32'))
1272
+ ),
1273
+ scale_init=<function ones at 0x7fb32b86e520>,
1274
+ use_bias=True,
1275
+ use_fast_variance=True,
1276
+ use_scale=True
1277
+ ),
1278
+ num_heads=4,
1279
+ proj_out=Conv( # Param: 4,160 (16.6 KB)
1280
+ bias=Param( # 64 (256 B)
1281
+ value=Array(shape=(64,), dtype=dtype('float32'))
1282
+ ),
1283
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1284
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1285
+ dtype=float32,
1286
+ feature_group_count=1,
1287
+ in_features=128,
1288
+ input_dilation=1,
1289
+ kernel=Param( # 4,096 (16.4 KB)
1290
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1291
+ ),
1292
+ kernel_dilation=1,
1293
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1294
+ kernel_shape=(1, 1, 128, 128),
1295
+ kernel_size=(1, 1),
1296
+ mask=None,
1297
+ out_features=128,
1298
+ padding='SAME',
1299
+ param_dtype=float32,
1300
+ precision=None,
1301
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1302
+ strides=1,
1303
+ use_bias=True
1304
+ ),
1305
+ q=Conv( # Param: 4,160 (16.6 KB)
1306
+ bias=Param( # 64 (256 B)
1307
+ value=Array(shape=(64,), dtype=dtype('float32'))
1308
+ ),
1309
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1310
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1311
+ dtype=float32,
1312
+ feature_group_count=1,
1313
+ in_features=128,
1314
+ input_dilation=1,
1315
+ kernel=Param( # 4,096 (16.4 KB)
1316
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1317
+ ),
1318
+ kernel_dilation=1,
1319
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1320
+ kernel_shape=(1, 1, 128, 128),
1321
+ kernel_size=(1, 1),
1322
+ mask=None,
1323
+ out_features=128,
1324
+ padding='SAME',
1325
+ param_dtype=float32,
1326
+ precision=None,
1327
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1328
+ strides=1,
1329
+ use_bias=True
1330
+ ),
1331
+ v=Conv( # Param: 4,160 (16.6 KB)
1332
+ bias=Param( # 64 (256 B)
1333
+ value=Array(shape=(64,), dtype=dtype('float32'))
1334
+ ),
1335
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1336
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1337
+ dtype=float32,
1338
+ feature_group_count=1,
1339
+ in_features=128,
1340
+ input_dilation=1,
1341
+ kernel=Param( # 4,096 (16.4 KB)
1342
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1343
+ ),
1344
+ kernel_dilation=1,
1345
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1346
+ kernel_shape=(1, 1, 128, 128),
1347
+ kernel_size=(1, 1),
1348
+ mask=None,
1349
+ out_features=128,
1350
+ padding='SAME',
1351
+ param_dtype=float32,
1352
+ precision=None,
1353
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1354
+ strides=1,
1355
+ use_bias=True
1356
+ )
1357
+ )]
1358
+ ), TimestepEmbedSequential( # Param: 316,736 (1.3 MB), RngState: 2 (12 B), Total: 316,738 (1.3 MB)
1359
+ layers=[ResnetBlock( # Param: 115,328 (461.3 KB), RngState: 2 (12 B), Total: 115,330 (461.3 KB)
1360
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1361
+ conv1=Conv( # Param: 55,360 (221.4 KB)
1362
+ bias=Param( # 64 (256 B)
1363
+ value=Array(shape=(64,), dtype=dtype('float32'))
1364
+ ),
1365
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1366
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1367
+ dtype=float32,
1368
+ feature_group_count=1,
1369
+ in_features=192,
1370
+ input_dilation=1,
1371
+ kernel=Param( # 55,296 (221.2 KB)
1372
+ value=Array(shape=(3, 3, 96, 64), dtype=dtype('float32'))
1373
+ ),
1374
+ kernel_dilation=1,
1375
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1376
+ kernel_shape=(3, 3, 192, 128),
1377
+ kernel_size=(3, 3),
1378
+ mask=None,
1379
+ out_features=128,
1380
+ padding=(1, 1),
1381
+ param_dtype=float32,
1382
+ precision=None,
1383
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1384
+ strides=(1, 1),
1385
+ use_bias=True
1386
+ ),
1387
+ conv2=Conv( # Param: 36,928 (147.7 KB)
1388
+ bias=Param( # 64 (256 B)
1389
+ value=Array(shape=(64,), dtype=dtype('float32'))
1390
+ ),
1391
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1392
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1393
+ dtype=float32,
1394
+ feature_group_count=1,
1395
+ in_features=128,
1396
+ input_dilation=1,
1397
+ kernel=Param( # 36,864 (147.5 KB)
1398
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1399
+ ),
1400
+ kernel_dilation=1,
1401
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1402
+ kernel_shape=(3, 3, 128, 128),
1403
+ kernel_size=(3, 3),
1404
+ mask=None,
1405
+ out_features=128,
1406
+ padding=(1, 1),
1407
+ param_dtype=float32,
1408
+ precision=None,
1409
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1410
+ strides=(1, 1),
1411
+ use_bias=True
1412
+ ),
1413
+ dropout=Dropout( # RngState: 2 (12 B)
1414
+ broadcast_dims=(),
1415
+ deterministic=False,
1416
+ rate=0.1,
1417
+ rng_collection='dropout',
1418
+ rngs=RngStream( # RngState: 2 (12 B)
1419
+ count=RngCount( # 1 (4 B)
1420
+ value=Array(0, dtype=uint32),
1421
+ tag='default'
1422
+ ),
1423
+ key=RngKey( # 1 (8 B)
1424
+ value=Array((), dtype=key<fry>) overlaying:
1425
+ [1596966061 1315822572],
1426
+ tag='default'
1427
+ ),
1428
+ tag='default'
1429
+ )
1430
+ ),
1431
+ embedding_dim=256,
1432
+ in_channels=192,
1433
+ nin_shortcut=Conv( # Param: 6,208 (24.8 KB)
1434
+ bias=Param( # 64 (256 B)
1435
+ value=Array(shape=(64,), dtype=dtype('float32'))
1436
+ ),
1437
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1438
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1439
+ dtype=float32,
1440
+ feature_group_count=1,
1441
+ in_features=192,
1442
+ input_dilation=1,
1443
+ kernel=Param( # 6,144 (24.6 KB)
1444
+ value=Array(shape=(1, 1, 96, 64), dtype=dtype('float32'))
1445
+ ),
1446
+ kernel_dilation=1,
1447
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1448
+ kernel_shape=(1, 1, 192, 128),
1449
+ kernel_size=(1, 1),
1450
+ mask=None,
1451
+ out_features=128,
1452
+ padding=(0, 0),
1453
+ param_dtype=float32,
1454
+ precision=None,
1455
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1456
+ strides=(1, 1),
1457
+ use_bias=True
1458
+ ),
1459
+ norm1=GroupNorm( # Param: 192 (768 B)
1460
+ axis_index_groups=None,
1461
+ axis_name=None,
1462
+ bias=Param( # 96 (384 B)
1463
+ value=Array(shape=(96,), dtype=dtype('float32'))
1464
+ ),
1465
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1466
+ dtype=float32,
1467
+ epsilon=1e-06,
1468
+ feature_axis=-1,
1469
+ group_size=6,
1470
+ num_groups=32,
1471
+ param_dtype=float32,
1472
+ reduction_axes=None,
1473
+ scale=Param( # 96 (384 B)
1474
+ value=Array(shape=(96,), dtype=dtype('float32'))
1475
+ ),
1476
+ scale_init=<function ones at 0x7fb32b86e520>,
1477
+ use_bias=True,
1478
+ use_fast_variance=True,
1479
+ use_scale=True
1480
+ ),
1481
+ norm2=GroupNorm( # Param: 128 (512 B)
1482
+ axis_index_groups=None,
1483
+ axis_name=None,
1484
+ bias=Param( # 64 (256 B)
1485
+ value=Array(shape=(64,), dtype=dtype('float32'))
1486
+ ),
1487
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1488
+ dtype=float32,
1489
+ epsilon=1e-06,
1490
+ feature_axis=-1,
1491
+ group_size=4,
1492
+ num_groups=32,
1493
+ param_dtype=float32,
1494
+ reduction_axes=None,
1495
+ scale=Param( # 64 (256 B)
1496
+ value=Array(shape=(64,), dtype=dtype('float32'))
1497
+ ),
1498
+ scale_init=<function ones at 0x7fb32b86e520>,
1499
+ use_bias=True,
1500
+ use_fast_variance=True,
1501
+ use_scale=True
1502
+ ),
1503
+ out_channels=128,
1504
+ time_mlp=Linear( # Param: 16,512 (66.0 KB)
1505
+ bias=Param( # 128 (512 B)
1506
+ value=Array(shape=(128,), dtype=dtype('float32'))
1507
+ ),
1508
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1509
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
1510
+ dtype=float32,
1511
+ in_features=256,
1512
+ kernel=Param( # 16,384 (65.5 KB)
1513
+ value=Array(shape=(128, 128), dtype=dtype('float32'))
1514
+ ),
1515
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1516
+ out_features=256,
1517
+ param_dtype=float32,
1518
+ precision=None,
1519
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1520
+ use_bias=True
1521
+ )
1522
+ ), AttnBlock( # Param: 16,768 (67.1 KB)
1523
+ dtype=float32,
1524
+ head_dim=32,
1525
+ k=Conv( # Param: 4,160 (16.6 KB)
1526
+ bias=Param( # 64 (256 B)
1527
+ value=Array(shape=(64,), dtype=dtype('float32'))
1528
+ ),
1529
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1530
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1531
+ dtype=float32,
1532
+ feature_group_count=1,
1533
+ in_features=128,
1534
+ input_dilation=1,
1535
+ kernel=Param( # 4,096 (16.4 KB)
1536
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1537
+ ),
1538
+ kernel_dilation=1,
1539
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1540
+ kernel_shape=(1, 1, 128, 128),
1541
+ kernel_size=(1, 1),
1542
+ mask=None,
1543
+ out_features=128,
1544
+ padding='SAME',
1545
+ param_dtype=float32,
1546
+ precision=None,
1547
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1548
+ strides=1,
1549
+ use_bias=True
1550
+ ),
1551
+ norm=GroupNorm( # Param: 128 (512 B)
1552
+ axis_index_groups=None,
1553
+ axis_name=None,
1554
+ bias=Param( # 64 (256 B)
1555
+ value=Array(shape=(64,), dtype=dtype('float32'))
1556
+ ),
1557
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1558
+ dtype=float32,
1559
+ epsilon=1e-06,
1560
+ feature_axis=-1,
1561
+ group_size=4,
1562
+ num_groups=32,
1563
+ param_dtype=float32,
1564
+ reduction_axes=None,
1565
+ scale=Param( # 64 (256 B)
1566
+ value=Array(shape=(64,), dtype=dtype('float32'))
1567
+ ),
1568
+ scale_init=<function ones at 0x7fb32b86e520>,
1569
+ use_bias=True,
1570
+ use_fast_variance=True,
1571
+ use_scale=True
1572
+ ),
1573
+ num_heads=4,
1574
+ proj_out=Conv( # Param: 4,160 (16.6 KB)
1575
+ bias=Param( # 64 (256 B)
1576
+ value=Array(shape=(64,), dtype=dtype('float32'))
1577
+ ),
1578
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1579
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1580
+ dtype=float32,
1581
+ feature_group_count=1,
1582
+ in_features=128,
1583
+ input_dilation=1,
1584
+ kernel=Param( # 4,096 (16.4 KB)
1585
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1586
+ ),
1587
+ kernel_dilation=1,
1588
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1589
+ kernel_shape=(1, 1, 128, 128),
1590
+ kernel_size=(1, 1),
1591
+ mask=None,
1592
+ out_features=128,
1593
+ padding='SAME',
1594
+ param_dtype=float32,
1595
+ precision=None,
1596
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1597
+ strides=1,
1598
+ use_bias=True
1599
+ ),
1600
+ q=Conv( # Param: 4,160 (16.6 KB)
1601
+ bias=Param( # 64 (256 B)
1602
+ value=Array(shape=(64,), dtype=dtype('float32'))
1603
+ ),
1604
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1605
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1606
+ dtype=float32,
1607
+ feature_group_count=1,
1608
+ in_features=128,
1609
+ input_dilation=1,
1610
+ kernel=Param( # 4,096 (16.4 KB)
1611
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1612
+ ),
1613
+ kernel_dilation=1,
1614
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1615
+ kernel_shape=(1, 1, 128, 128),
1616
+ kernel_size=(1, 1),
1617
+ mask=None,
1618
+ out_features=128,
1619
+ padding='SAME',
1620
+ param_dtype=float32,
1621
+ precision=None,
1622
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1623
+ strides=1,
1624
+ use_bias=True
1625
+ ),
1626
+ v=Conv( # Param: 4,160 (16.6 KB)
1627
+ bias=Param( # 64 (256 B)
1628
+ value=Array(shape=(64,), dtype=dtype('float32'))
1629
+ ),
1630
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1631
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1632
+ dtype=float32,
1633
+ feature_group_count=1,
1634
+ in_features=128,
1635
+ input_dilation=1,
1636
+ kernel=Param( # 4,096 (16.4 KB)
1637
+ value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
1638
+ ),
1639
+ kernel_dilation=1,
1640
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1641
+ kernel_shape=(1, 1, 128, 128),
1642
+ kernel_size=(1, 1),
1643
+ mask=None,
1644
+ out_features=128,
1645
+ padding='SAME',
1646
+ param_dtype=float32,
1647
+ precision=None,
1648
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1649
+ strides=1,
1650
+ use_bias=True
1651
+ )
1652
+ ), Upsample( # Param: 184,640 (738.6 KB)
1653
+ conv_pixel_shuffle=Conv( # Param: 147,712 (590.8 KB)
1654
+ bias=Param( # 256 (1.0 KB)
1655
+ value=Array(shape=(256,), dtype=dtype('float32'))
1656
+ ),
1657
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1658
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1659
+ dtype=float32,
1660
+ feature_group_count=1,
1661
+ in_features=128,
1662
+ input_dilation=1,
1663
+ kernel=Param( # 147,456 (589.8 KB)
1664
+ value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32'))
1665
+ ),
1666
+ kernel_dilation=1,
1667
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1668
+ kernel_shape=(3, 3, 128, 512),
1669
+ kernel_size=(3, 3),
1670
+ mask=None,
1671
+ out_features=512,
1672
+ padding=(1, 1),
1673
+ param_dtype=float32,
1674
+ precision=None,
1675
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1676
+ strides=(1, 1),
1677
+ use_bias=True
1678
+ ),
1679
+ conv_resize=Conv( # Param: 36,928 (147.7 KB)
1680
+ bias=Param( # 64 (256 B)
1681
+ value=Array(shape=(64,), dtype=dtype('float32'))
1682
+ ),
1683
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1684
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1685
+ dtype=float32,
1686
+ feature_group_count=1,
1687
+ in_features=128,
1688
+ input_dilation=1,
1689
+ kernel=Param( # 36,864 (147.5 KB)
1690
+ value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
1691
+ ),
1692
+ kernel_dilation=1,
1693
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1694
+ kernel_shape=(3, 3, 128, 128),
1695
+ kernel_size=(3, 3),
1696
+ mask=None,
1697
+ out_features=128,
1698
+ padding=(1, 1),
1699
+ param_dtype=float32,
1700
+ precision=None,
1701
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1702
+ strides=(1, 1),
1703
+ use_bias=True
1704
+ ),
1705
+ method='pixel_shuffle',
1706
+ pixel_shuffle=PixelShuffle(
1707
+ scale=2
1708
+ ),
1709
+ scale_factor=2
1710
+ )]
1711
+ ), TimestepEmbedSequential( # Param: 48,544 (194.2 KB), RngState: 2 (12 B), Total: 48,546 (194.2 KB)
1712
+ layers=[ResnetBlock( # Param: 48,544 (194.2 KB), RngState: 2 (12 B), Total: 48,546 (194.2 KB)
1713
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1714
+ conv1=Conv( # Param: 27,680 (110.7 KB)
1715
+ bias=Param( # 32 (128 B)
1716
+ value=Array(shape=(32,), dtype=dtype('float32'))
1717
+ ),
1718
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1719
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1720
+ dtype=float32,
1721
+ feature_group_count=1,
1722
+ in_features=192,
1723
+ input_dilation=1,
1724
+ kernel=Param( # 27,648 (110.6 KB)
1725
+ value=Array(shape=(3, 3, 96, 32), dtype=dtype('float32'))
1726
+ ),
1727
+ kernel_dilation=1,
1728
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1729
+ kernel_shape=(3, 3, 192, 64),
1730
+ kernel_size=(3, 3),
1731
+ mask=None,
1732
+ out_features=64,
1733
+ padding=(1, 1),
1734
+ param_dtype=float32,
1735
+ precision=None,
1736
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1737
+ strides=(1, 1),
1738
+ use_bias=True
1739
+ ),
1740
+ conv2=Conv( # Param: 9,248 (37.0 KB)
1741
+ bias=Param( # 32 (128 B)
1742
+ value=Array(shape=(32,), dtype=dtype('float32'))
1743
+ ),
1744
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1745
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1746
+ dtype=float32,
1747
+ feature_group_count=1,
1748
+ in_features=64,
1749
+ input_dilation=1,
1750
+ kernel=Param( # 9,216 (36.9 KB)
1751
+ value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
1752
+ ),
1753
+ kernel_dilation=1,
1754
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1755
+ kernel_shape=(3, 3, 64, 64),
1756
+ kernel_size=(3, 3),
1757
+ mask=None,
1758
+ out_features=64,
1759
+ padding=(1, 1),
1760
+ param_dtype=float32,
1761
+ precision=None,
1762
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1763
+ strides=(1, 1),
1764
+ use_bias=True
1765
+ ),
1766
+ dropout=Dropout( # RngState: 2 (12 B)
1767
+ broadcast_dims=(),
1768
+ deterministic=False,
1769
+ rate=0.1,
1770
+ rng_collection='dropout',
1771
+ rngs=RngStream( # RngState: 2 (12 B)
1772
+ count=RngCount( # 1 (4 B)
1773
+ value=Array(0, dtype=uint32),
1774
+ tag='default'
1775
+ ),
1776
+ key=RngKey( # 1 (8 B)
1777
+ value=Array((), dtype=key<fry>) overlaying:
1778
+ [2550820645 2818876438],
1779
+ tag='default'
1780
+ ),
1781
+ tag='default'
1782
+ )
1783
+ ),
1784
+ embedding_dim=256,
1785
+ in_channels=192,
1786
+ nin_shortcut=Conv( # Param: 3,104 (12.4 KB)
1787
+ bias=Param( # 32 (128 B)
1788
+ value=Array(shape=(32,), dtype=dtype('float32'))
1789
+ ),
1790
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1791
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1792
+ dtype=float32,
1793
+ feature_group_count=1,
1794
+ in_features=192,
1795
+ input_dilation=1,
1796
+ kernel=Param( # 3,072 (12.3 KB)
1797
+ value=Array(shape=(1, 1, 96, 32), dtype=dtype('float32'))
1798
+ ),
1799
+ kernel_dilation=1,
1800
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1801
+ kernel_shape=(1, 1, 192, 64),
1802
+ kernel_size=(1, 1),
1803
+ mask=None,
1804
+ out_features=64,
1805
+ padding=(0, 0),
1806
+ param_dtype=float32,
1807
+ precision=None,
1808
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1809
+ strides=(1, 1),
1810
+ use_bias=True
1811
+ ),
1812
+ norm1=GroupNorm( # Param: 192 (768 B)
1813
+ axis_index_groups=None,
1814
+ axis_name=None,
1815
+ bias=Param( # 96 (384 B)
1816
+ value=Array(shape=(96,), dtype=dtype('float32'))
1817
+ ),
1818
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1819
+ dtype=float32,
1820
+ epsilon=1e-06,
1821
+ feature_axis=-1,
1822
+ group_size=6,
1823
+ num_groups=32,
1824
+ param_dtype=float32,
1825
+ reduction_axes=None,
1826
+ scale=Param( # 96 (384 B)
1827
+ value=Array(shape=(96,), dtype=dtype('float32'))
1828
+ ),
1829
+ scale_init=<function ones at 0x7fb32b86e520>,
1830
+ use_bias=True,
1831
+ use_fast_variance=True,
1832
+ use_scale=True
1833
+ ),
1834
+ norm2=GroupNorm( # Param: 64 (256 B)
1835
+ axis_index_groups=None,
1836
+ axis_name=None,
1837
+ bias=Param( # 32 (128 B)
1838
+ value=Array(shape=(32,), dtype=dtype('float32'))
1839
+ ),
1840
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1841
+ dtype=float32,
1842
+ epsilon=1e-06,
1843
+ feature_axis=-1,
1844
+ group_size=2,
1845
+ num_groups=32,
1846
+ param_dtype=float32,
1847
+ reduction_axes=None,
1848
+ scale=Param( # 32 (128 B)
1849
+ value=Array(shape=(32,), dtype=dtype('float32'))
1850
+ ),
1851
+ scale_init=<function ones at 0x7fb32b86e520>,
1852
+ use_bias=True,
1853
+ use_fast_variance=True,
1854
+ use_scale=True
1855
+ ),
1856
+ out_channels=64,
1857
+ time_mlp=Linear( # Param: 8,256 (33.0 KB)
1858
+ bias=Param( # 64 (256 B)
1859
+ value=Array(shape=(64,), dtype=dtype('float32'))
1860
+ ),
1861
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1862
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
1863
+ dtype=float32,
1864
+ in_features=256,
1865
+ kernel=Param( # 8,192 (32.8 KB)
1866
+ value=Array(shape=(128, 64), dtype=dtype('float32'))
1867
+ ),
1868
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1869
+ out_features=128,
1870
+ param_dtype=float32,
1871
+ precision=None,
1872
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1873
+ use_bias=True
1874
+ )
1875
+ )]
1876
+ ), TimestepEmbedSequential( # Param: 38,240 (153.0 KB), RngState: 2 (12 B), Total: 38,242 (153.0 KB)
1877
+ layers=[ResnetBlock( # Param: 38,240 (153.0 KB), RngState: 2 (12 B), Total: 38,242 (153.0 KB)
1878
+ activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
1879
+ conv1=Conv( # Param: 18,464 (73.9 KB)
1880
+ bias=Param( # 32 (128 B)
1881
+ value=Array(shape=(32,), dtype=dtype('float32'))
1882
+ ),
1883
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1884
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1885
+ dtype=float32,
1886
+ feature_group_count=1,
1887
+ in_features=128,
1888
+ input_dilation=1,
1889
+ kernel=Param( # 18,432 (73.7 KB)
1890
+ value=Array(shape=(3, 3, 64, 32), dtype=dtype('float32'))
1891
+ ),
1892
+ kernel_dilation=1,
1893
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1894
+ kernel_shape=(3, 3, 128, 64),
1895
+ kernel_size=(3, 3),
1896
+ mask=None,
1897
+ out_features=64,
1898
+ padding=(1, 1),
1899
+ param_dtype=float32,
1900
+ precision=None,
1901
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1902
+ strides=(1, 1),
1903
+ use_bias=True
1904
+ ),
1905
+ conv2=Conv( # Param: 9,248 (37.0 KB)
1906
+ bias=Param( # 32 (128 B)
1907
+ value=Array(shape=(32,), dtype=dtype('float32'))
1908
+ ),
1909
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1910
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1911
+ dtype=float32,
1912
+ feature_group_count=1,
1913
+ in_features=64,
1914
+ input_dilation=1,
1915
+ kernel=Param( # 9,216 (36.9 KB)
1916
+ value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
1917
+ ),
1918
+ kernel_dilation=1,
1919
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1920
+ kernel_shape=(3, 3, 64, 64),
1921
+ kernel_size=(3, 3),
1922
+ mask=None,
1923
+ out_features=64,
1924
+ padding=(1, 1),
1925
+ param_dtype=float32,
1926
+ precision=None,
1927
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1928
+ strides=(1, 1),
1929
+ use_bias=True
1930
+ ),
1931
+ dropout=Dropout( # RngState: 2 (12 B)
1932
+ broadcast_dims=(),
1933
+ deterministic=False,
1934
+ rate=0.1,
1935
+ rng_collection='dropout',
1936
+ rngs=RngStream( # RngState: 2 (12 B)
1937
+ count=RngCount( # 1 (4 B)
1938
+ value=Array(0, dtype=uint32),
1939
+ tag='default'
1940
+ ),
1941
+ key=RngKey( # 1 (8 B)
1942
+ value=Array((), dtype=key<fry>) overlaying:
1943
+ [1975238715 3717004500],
1944
+ tag='default'
1945
+ ),
1946
+ tag='default'
1947
+ )
1948
+ ),
1949
+ embedding_dim=256,
1950
+ in_channels=128,
1951
+ nin_shortcut=Conv( # Param: 2,080 (8.3 KB)
1952
+ bias=Param( # 32 (128 B)
1953
+ value=Array(shape=(32,), dtype=dtype('float32'))
1954
+ ),
1955
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1956
+ conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
1957
+ dtype=float32,
1958
+ feature_group_count=1,
1959
+ in_features=128,
1960
+ input_dilation=1,
1961
+ kernel=Param( # 2,048 (8.2 KB)
1962
+ value=Array(shape=(1, 1, 64, 32), dtype=dtype('float32'))
1963
+ ),
1964
+ kernel_dilation=1,
1965
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
1966
+ kernel_shape=(1, 1, 128, 64),
1967
+ kernel_size=(1, 1),
1968
+ mask=None,
1969
+ out_features=64,
1970
+ padding=(0, 0),
1971
+ param_dtype=float32,
1972
+ precision=None,
1973
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
1974
+ strides=(1, 1),
1975
+ use_bias=True
1976
+ ),
1977
+ norm1=GroupNorm( # Param: 128 (512 B)
1978
+ axis_index_groups=None,
1979
+ axis_name=None,
1980
+ bias=Param( # 64 (256 B)
1981
+ value=Array(shape=(64,), dtype=dtype('float32'))
1982
+ ),
1983
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
1984
+ dtype=float32,
1985
+ epsilon=1e-06,
1986
+ feature_axis=-1,
1987
+ group_size=4,
1988
+ num_groups=32,
1989
+ param_dtype=float32,
1990
+ reduction_axes=None,
1991
+ scale=Param( # 64 (256 B)
1992
+ value=Array(shape=(64,), dtype=dtype('float32'))
1993
+ ),
1994
+ scale_init=<function ones at 0x7fb32b86e520>,
1995
+ use_bias=True,
1996
+ use_fast_variance=True,
1997
+ use_scale=True
1998
+ ),
1999
+ norm2=GroupNorm( # Param: 64 (256 B)
2000
+ axis_index_groups=None,
2001
+ axis_name=None,
2002
+ bias=Param( # 32 (128 B)
2003
+ value=Array(shape=(32,), dtype=dtype('float32'))
2004
+ ),
2005
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
2006
+ dtype=float32,
2007
+ epsilon=1e-06,
2008
+ feature_axis=-1,
2009
+ group_size=2,
2010
+ num_groups=32,
2011
+ param_dtype=float32,
2012
+ reduction_axes=None,
2013
+ scale=Param( # 32 (128 B)
2014
+ value=Array(shape=(32,), dtype=dtype('float32'))
2015
+ ),
2016
+ scale_init=<function ones at 0x7fb32b86e520>,
2017
+ use_bias=True,
2018
+ use_fast_variance=True,
2019
+ use_scale=True
2020
+ ),
2021
+ out_channels=64,
2022
+ time_mlp=Linear( # Param: 8,256 (33.0 KB)
2023
+ bias=Param( # 64 (256 B)
2024
+ value=Array(shape=(64,), dtype=dtype('float32'))
2025
+ ),
2026
+ bias_init=<function zeros at 0x7fb32b98c2c0>,
2027
+ dot_general=<function dot_general at 0x7fb32c21a3e0>,
2028
+ dtype=float32,
2029
+ in_features=256,
2030
+ kernel=Param( # 8,192 (32.8 KB)
2031
+ value=Array(shape=(128, 64), dtype=dtype('float32'))
2032
+ ),
2033
+ kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
2034
+ out_features=128,
2035
+ param_dtype=float32,
2036
+ precision=None,
2037
+ promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
2038
+ use_bias=True
2039
+ )
2040
+ )]
2041
+ )],
2042
  )