mnist / config.py
jcopo's picture
Upload model at step 45000
0a02da1 verified
raw
history blame
66 kB
"""Model configuration for jcopo/mnist
This file contains the model architecture definition.
Training step: 45000
Precision: float32
"""
from triax.models.nn.condUNet import CondUNet2D
import jax.numpy as jnp
from flax import nnx
# Model architecture
model = CondUNet2D(
blocks_down=[TimestepEmbedSequential( # Param: 320 (1.3 KB)
layers=[Conv( # Param: 320 (1.3 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=1,
input_dilation=1,
kernel=Param( # 288 (1.2 KB)
value=Array(shape=(3, 3, 1, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 1, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
)]
), TimestepEmbedSequential( # Param: 26,880 (107.5 KB), RngState: 2 (12 B), Total: 26,882 (107.5 KB)
layers=[ResnetBlock( # Param: 26,880 (107.5 KB), RngState: 2 (12 B), Total: 26,882 (107.5 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 9,248 (37.0 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 9,216 (36.9 KB)
value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 64, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 9,248 (37.0 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 9,216 (36.9 KB)
value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 64, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[2585633080 2083471411],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=64,
norm1=GroupNorm( # Param: 64 (256 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=2,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 64 (256 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=2,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=64,
time_mlp=Linear( # Param: 8,256 (33.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 8,192 (32.8 KB)
value=Array(shape=(128, 64), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=128,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
)]
), TimestepEmbedSequential( # Param: 9,248 (37.0 KB)
layers=[Downsample( # Param: 9,248 (37.0 KB)
conv=Conv( # Param: 9,248 (37.0 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 9,216 (36.9 KB)
value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 64, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(2, 2),
use_bias=True
),
method='conv'
)]
), TimestepEmbedSequential( # Param: 91,008 (364.0 KB), RngState: 2 (12 B), Total: 91,010 (364.0 KB)
layers=[ResnetBlock( # Param: 74,240 (297.0 KB), RngState: 2 (12 B), Total: 74,242 (297.0 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 18,496 (74.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 18,432 (73.7 KB)
value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 64, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[2656139193 2766658851],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=64,
nin_shortcut=Conv( # Param: 2,112 (8.4 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 2,048 (8.2 KB)
value=Array(shape=(1, 1, 32, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 64, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 64 (256 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=2,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=128,
time_mlp=Linear( # Param: 16,512 (66.0 KB)
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 16,384 (65.5 KB)
value=Array(shape=(128, 128), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=256,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
), AttnBlock( # Param: 16,768 (67.1 KB)
dtype=float32,
head_dim=32,
k=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
norm=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
num_heads=4,
proj_out=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
q=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
v=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
)
)]
), TimestepEmbedSequential( # Param: 36,928 (147.7 KB)
layers=[Downsample( # Param: 36,928 (147.7 KB)
conv=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(2, 2),
use_bias=True
),
method='conv'
)]
), TimestepEmbedSequential( # Param: 90,624 (362.5 KB), RngState: 2 (12 B), Total: 90,626 (362.5 KB)
layers=[ResnetBlock( # Param: 90,624 (362.5 KB), RngState: 2 (12 B), Total: 90,626 (362.5 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[ 692128146 1043829861],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=128,
norm1=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=128,
time_mlp=Linear( # Param: 16,512 (66.0 KB)
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 16,384 (65.5 KB)
value=Array(shape=(128, 128), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=256,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
)]
)],
blocks_up=[TimestepEmbedSequential( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 73,792 (295.2 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=256,
input_dilation=1,
kernel=Param( # 73,728 (294.9 KB)
value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 256, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[2853902436 2217684095],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=256,
nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=256,
input_dilation=1,
kernel=Param( # 8,192 (32.8 KB)
value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 256, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 256 (1.0 KB)
axis_index_groups=None,
axis_name=None,
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=8,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=128,
time_mlp=Linear( # Param: 16,512 (66.0 KB)
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 16,384 (65.5 KB)
value=Array(shape=(128, 128), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=256,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
)]
), TimestepEmbedSequential( # Param: 320,512 (1.3 MB), RngState: 2 (12 B), Total: 320,514 (1.3 MB)
layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 73,792 (295.2 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=256,
input_dilation=1,
kernel=Param( # 73,728 (294.9 KB)
value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 256, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[2785098898 841100811],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=256,
nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=256,
input_dilation=1,
kernel=Param( # 8,192 (32.8 KB)
value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 256, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 256 (1.0 KB)
axis_index_groups=None,
axis_name=None,
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=8,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=128,
time_mlp=Linear( # Param: 16,512 (66.0 KB)
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 16,384 (65.5 KB)
value=Array(shape=(128, 128), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=256,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
), Upsample( # Param: 184,640 (738.6 KB)
conv_pixel_shuffle=Conv( # Param: 147,712 (590.8 KB)
bias=Param( # 256 (1.0 KB)
value=Array(shape=(256,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 147,456 (589.8 KB)
value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 512),
kernel_size=(3, 3),
mask=None,
out_features=512,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv_resize=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
method='pixel_shuffle',
pixel_shuffle=PixelShuffle(
scale=2
),
scale_factor=2
)]
), TimestepEmbedSequential( # Param: 152,640 (610.6 KB), RngState: 2 (12 B), Total: 152,642 (610.6 KB)
layers=[ResnetBlock( # Param: 135,872 (543.5 KB), RngState: 2 (12 B), Total: 135,874 (543.5 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 73,792 (295.2 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=256,
input_dilation=1,
kernel=Param( # 73,728 (294.9 KB)
value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 256, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[ 48802331 1548237274],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=256,
nin_shortcut=Conv( # Param: 8,256 (33.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=256,
input_dilation=1,
kernel=Param( # 8,192 (32.8 KB)
value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 256, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 256 (1.0 KB)
axis_index_groups=None,
axis_name=None,
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=8,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=128,
time_mlp=Linear( # Param: 16,512 (66.0 KB)
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 16,384 (65.5 KB)
value=Array(shape=(128, 128), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=256,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
), AttnBlock( # Param: 16,768 (67.1 KB)
dtype=float32,
head_dim=32,
k=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
norm=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
num_heads=4,
proj_out=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
q=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
v=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
)
)]
), TimestepEmbedSequential( # Param: 316,736 (1.3 MB), RngState: 2 (12 B), Total: 316,738 (1.3 MB)
layers=[ResnetBlock( # Param: 115,328 (461.3 KB), RngState: 2 (12 B), Total: 115,330 (461.3 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 55,360 (221.4 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=192,
input_dilation=1,
kernel=Param( # 55,296 (221.2 KB)
value=Array(shape=(3, 3, 96, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 192, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[1596966061 1315822572],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=192,
nin_shortcut=Conv( # Param: 6,208 (24.8 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=192,
input_dilation=1,
kernel=Param( # 6,144 (24.6 KB)
value=Array(shape=(1, 1, 96, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 192, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 192 (768 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 96 (384 B)
value=Array(shape=(96,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=6,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 96 (384 B)
value=Array(shape=(96,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=128,
time_mlp=Linear( # Param: 16,512 (66.0 KB)
bias=Param( # 128 (512 B)
value=Array(shape=(128,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 16,384 (65.5 KB)
value=Array(shape=(128, 128), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=256,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
), AttnBlock( # Param: 16,768 (67.1 KB)
dtype=float32,
head_dim=32,
k=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
norm=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
num_heads=4,
proj_out=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
q=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
),
v=Conv( # Param: 4,160 (16.6 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 4,096 (16.4 KB)
value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 128),
kernel_size=(1, 1),
mask=None,
out_features=128,
padding='SAME',
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=1,
use_bias=True
)
), Upsample( # Param: 184,640 (738.6 KB)
conv_pixel_shuffle=Conv( # Param: 147,712 (590.8 KB)
bias=Param( # 256 (1.0 KB)
value=Array(shape=(256,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 147,456 (589.8 KB)
value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 512),
kernel_size=(3, 3),
mask=None,
out_features=512,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv_resize=Conv( # Param: 36,928 (147.7 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 36,864 (147.5 KB)
value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 128),
kernel_size=(3, 3),
mask=None,
out_features=128,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
method='pixel_shuffle',
pixel_shuffle=PixelShuffle(
scale=2
),
scale_factor=2
)]
), TimestepEmbedSequential( # Param: 48,544 (194.2 KB), RngState: 2 (12 B), Total: 48,546 (194.2 KB)
layers=[ResnetBlock( # Param: 48,544 (194.2 KB), RngState: 2 (12 B), Total: 48,546 (194.2 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 27,680 (110.7 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=192,
input_dilation=1,
kernel=Param( # 27,648 (110.6 KB)
value=Array(shape=(3, 3, 96, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 192, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 9,248 (37.0 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 9,216 (36.9 KB)
value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 64, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[2550820645 2818876438],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=192,
nin_shortcut=Conv( # Param: 3,104 (12.4 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=192,
input_dilation=1,
kernel=Param( # 3,072 (12.3 KB)
value=Array(shape=(1, 1, 96, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 192, 64),
kernel_size=(1, 1),
mask=None,
out_features=64,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 192 (768 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 96 (384 B)
value=Array(shape=(96,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=6,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 96 (384 B)
value=Array(shape=(96,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 64 (256 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=2,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=64,
time_mlp=Linear( # Param: 8,256 (33.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 8,192 (32.8 KB)
value=Array(shape=(128, 64), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=128,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
)]
), TimestepEmbedSequential( # Param: 38,240 (153.0 KB), RngState: 2 (12 B), Total: 38,242 (153.0 KB)
layers=[ResnetBlock( # Param: 38,240 (153.0 KB), RngState: 2 (12 B), Total: 38,242 (153.0 KB)
activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>,
conv1=Conv( # Param: 18,464 (73.9 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 18,432 (73.7 KB)
value=Array(shape=(3, 3, 64, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 128, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
conv2=Conv( # Param: 9,248 (37.0 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=64,
input_dilation=1,
kernel=Param( # 9,216 (36.9 KB)
value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(3, 3, 64, 64),
kernel_size=(3, 3),
mask=None,
out_features=64,
padding=(1, 1),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
dropout=Dropout( # RngState: 2 (12 B)
broadcast_dims=(),
deterministic=False,
rate=0.1,
rng_collection='dropout',
rngs=RngStream( # RngState: 2 (12 B)
count=RngCount( # 1 (4 B)
value=Array(0, dtype=uint32),
tag='default'
),
key=RngKey( # 1 (8 B)
value=Array((), dtype=key<fry>) overlaying:
[1975238715 3717004500],
tag='default'
),
tag='default'
)
),
embedding_dim=256,
in_channels=128,
nin_shortcut=Conv( # Param: 2,080 (8.3 KB)
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>,
dtype=float32,
feature_group_count=1,
in_features=128,
input_dilation=1,
kernel=Param( # 2,048 (8.2 KB)
value=Array(shape=(1, 1, 64, 32), dtype=dtype('float32'))
),
kernel_dilation=1,
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
kernel_shape=(1, 1, 128, 64),
kernel_size=(1, 1),
mask=None,
out_features=64,
padding=(0, 0),
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
strides=(1, 1),
use_bias=True
),
norm1=GroupNorm( # Param: 128 (512 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=4,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
norm2=GroupNorm( # Param: 64 (256 B)
axis_index_groups=None,
axis_name=None,
bias=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dtype=float32,
epsilon=1e-06,
feature_axis=-1,
group_size=2,
num_groups=32,
param_dtype=float32,
reduction_axes=None,
scale=Param( # 32 (128 B)
value=Array(shape=(32,), dtype=dtype('float32'))
),
scale_init=<function ones at 0x7fb32b86e520>,
use_bias=True,
use_fast_variance=True,
use_scale=True
),
out_channels=64,
time_mlp=Linear( # Param: 8,256 (33.0 KB)
bias=Param( # 64 (256 B)
value=Array(shape=(64,), dtype=dtype('float32'))
),
bias_init=<function zeros at 0x7fb32b98c2c0>,
dot_general=<function dot_general at 0x7fb32c21a3e0>,
dtype=float32,
in_features=256,
kernel=Param( # 8,192 (32.8 KB)
value=Array(shape=(128, 64), dtype=dtype('float32'))
),
kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>,
out_features=128,
param_dtype=float32,
precision=None,
promote_dtype=<function promote_dtype at 0x7fb3299d7380>,
use_bias=True
)
)]
)],
)