| """Model configuration for jcopo/mnist |
| |
| This file contains the model architecture definition. |
| Training step: 45000 |
| Precision: float32 |
| """ |
|
|
| from triax.models.nn.condUNet import CondUNet2D |
| import jax.numpy as jnp |
| from flax import nnx |
|
|
| |
| model = CondUNet2D( |
| blocks_down=[TimestepEmbedSequential( |
| layers=[Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=1, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 1, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 1, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 64, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 64, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [2585633080 2083471411], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=64, |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=2, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=2, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=64, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 64), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=128, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| )] |
| ), TimestepEmbedSequential( |
| layers=[Downsample( |
| conv=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 64, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(2, 2), |
| use_bias=True |
| ), |
| method='conv' |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 64, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [2656139193 2766658851], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=64, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 32, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 64, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=2, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=128, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 128), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=256, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| ), AttnBlock( |
| dtype=float32, |
| head_dim=32, |
| k=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| norm=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| num_heads=4, |
| proj_out=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| q=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| v=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ) |
| )] |
| ), TimestepEmbedSequential( |
| layers=[Downsample( |
| conv=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(2, 2), |
| use_bias=True |
| ), |
| method='conv' |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [ 692128146 1043829861], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=128, |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=128, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 128), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=256, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| )] |
| )], |
| blocks_up=[TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=256, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 256, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [2853902436 2217684095], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=256, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=256, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 256, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=8, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=128, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 128), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=256, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=256, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 256, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [2785098898 841100811], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=256, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=256, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 256, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=8, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=128, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 128), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=256, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| ), Upsample( |
| conv_pixel_shuffle=Conv( |
| bias=Param( |
| value=Array(shape=(256,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 512), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=512, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv_resize=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| method='pixel_shuffle', |
| pixel_shuffle=PixelShuffle( |
| scale=2 |
| ), |
| scale_factor=2 |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=256, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 128, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 256, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [ 48802331 1548237274], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=256, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=256, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 128, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 256, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=8, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=128, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 128), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=256, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| ), AttnBlock( |
| dtype=float32, |
| head_dim=32, |
| k=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| norm=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| num_heads=4, |
| proj_out=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| q=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| v=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ) |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=192, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 96, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 192, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [1596966061 1315822572], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=192, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=192, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 96, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 192, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(96,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=6, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(96,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=128, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(128,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 128), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=256, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| ), AttnBlock( |
| dtype=float32, |
| head_dim=32, |
| k=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| norm=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| num_heads=4, |
| proj_out=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| q=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ), |
| v=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 128), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=128, |
| padding='SAME', |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=1, |
| use_bias=True |
| ) |
| ), Upsample( |
| conv_pixel_shuffle=Conv( |
| bias=Param( |
| value=Array(shape=(256,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 256), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 512), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=512, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv_resize=Conv( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 64), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 128), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=128, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| method='pixel_shuffle', |
| pixel_shuffle=PixelShuffle( |
| scale=2 |
| ), |
| scale_factor=2 |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=192, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 96, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 192, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 64, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [2550820645 2818876438], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=192, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=192, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 96, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 192, 64), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=64, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(96,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=6, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(96,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=2, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=64, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 64), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=128, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| )] |
| ), TimestepEmbedSequential( |
| layers=[ResnetBlock( |
| activation=<PjitFunction of <function silu at 0x7fb32b8440e0>>, |
| conv1=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 64, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 128, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| conv2=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=64, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(3, 3, 32, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(3, 3, 64, 64), |
| kernel_size=(3, 3), |
| mask=None, |
| out_features=64, |
| padding=(1, 1), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| dropout=Dropout( |
| broadcast_dims=(), |
| deterministic=False, |
| rate=0.1, |
| rng_collection='dropout', |
| rngs=RngStream( |
| count=RngCount( |
| value=Array(0, dtype=uint32), |
| tag='default' |
| ), |
| key=RngKey( |
| value=Array((), dtype=key<fry>) overlaying: |
| [1975238715 3717004500], |
| tag='default' |
| ), |
| tag='default' |
| ) |
| ), |
| embedding_dim=256, |
| in_channels=128, |
| nin_shortcut=Conv( |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| conv_general_dilated=<function conv_general_dilated at 0x7fb32c260ea0>, |
| dtype=float32, |
| feature_group_count=1, |
| in_features=128, |
| input_dilation=1, |
| kernel=Param( |
| value=Array(shape=(1, 1, 64, 32), dtype=dtype('float32')) |
| ), |
| kernel_dilation=1, |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| kernel_shape=(1, 1, 128, 64), |
| kernel_size=(1, 1), |
| mask=None, |
| out_features=64, |
| padding=(0, 0), |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| strides=(1, 1), |
| use_bias=True |
| ), |
| norm1=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=4, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| norm2=GroupNorm( |
| axis_index_groups=None, |
| axis_name=None, |
| bias=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dtype=float32, |
| epsilon=1e-06, |
| feature_axis=-1, |
| group_size=2, |
| num_groups=32, |
| param_dtype=float32, |
| reduction_axes=None, |
| scale=Param( |
| value=Array(shape=(32,), dtype=dtype('float32')) |
| ), |
| scale_init=<function ones at 0x7fb32b86e520>, |
| use_bias=True, |
| use_fast_variance=True, |
| use_scale=True |
| ), |
| out_channels=64, |
| time_mlp=Linear( |
| bias=Param( |
| value=Array(shape=(64,), dtype=dtype('float32')) |
| ), |
| bias_init=<function zeros at 0x7fb32b98c2c0>, |
| dot_general=<function dot_general at 0x7fb32c21a3e0>, |
| dtype=float32, |
| in_features=256, |
| kernel=Param( |
| value=Array(shape=(128, 64), dtype=dtype('float32')) |
| ), |
| kernel_init=<function variance_scaling.<locals>.init at 0x7fb3299d7060>, |
| out_features=128, |
| param_dtype=float32, |
| precision=None, |
| promote_dtype=<function promote_dtype at 0x7fb3299d7380>, |
| use_bias=True |
| ) |
| )] |
| )], |
| ) |
|
|