| {"tree_metadata": {"('decoder', 'conv_in', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "conv_in", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'conv_in', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "conv_in", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 16, 512]}}, "('decoder', 'conv_norm_out', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "conv_norm_out", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'conv_norm_out', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "conv_norm_out", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'conv_out', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "conv_out", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3]}}, "('decoder', 'conv_out', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "conv_out", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 3]}}, "('decoder', 'mid_block', 'attentions', '0', 'group_norm', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "group_norm", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'attentions', '0', 'group_norm', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "group_norm", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_k', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_k', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_out', '0', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_out", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_out', '0', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_out", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_q', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_q', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_v', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'attentions', '0', 'to_v', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('decoder', 'mid_block', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'mid_block', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'mid_block', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'mid_block', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'mid_block', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'mid_block', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'resnets', '2', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'upsamplers', '0', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "upsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '0', 'upsamplers', '0', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "upsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'resnets', '2', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'upsamplers', '0', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "upsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '1', 'upsamplers', '0', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "upsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'conv_shortcut', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'conv_shortcut', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'resnets', '2', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'upsamplers', '0', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "upsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '2', 'upsamplers', '0', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "upsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'conv_shortcut', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'conv_shortcut', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 256, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('decoder', 'up_blocks', '3', 'resnets', '2', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "decoder", "key_type": 2}, {"key": "up_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'conv_in', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "conv_in", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'conv_in', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "conv_in", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 3, 128]}}, "('encoder', 'conv_norm_out', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "conv_norm_out", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'conv_norm_out', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "conv_norm_out", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'conv_out', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "conv_out", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [32]}}, "('encoder', 'conv_out', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "conv_out", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 32]}}, "('encoder', 'down_blocks', '0', 'downsamplers', '0', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "downsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'downsamplers', '0', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "downsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '0', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '1', 'downsamplers', '0', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "downsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'downsamplers', '0', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "downsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 128, 256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'conv_shortcut', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'conv_shortcut', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 128, 256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '1', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '2', 'downsamplers', '0', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "downsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'downsamplers', '0', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "downsamplers", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 256, 512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'conv_shortcut', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'conv_shortcut', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 256, 512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [256]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '2', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'down_blocks', '3', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "down_blocks", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'group_norm', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "group_norm", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'group_norm', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "group_norm", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_k', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_k', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_out', '0', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_out", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_out', '0', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_out", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_q', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_q', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_v', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'attentions', '0', 'to_v', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 512, 512]}}, "('encoder', 'mid_block', 'resnets', '0', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '0', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'mid_block', 'resnets', '0', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '0', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'mid_block', 'resnets', '0', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '0', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '0', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '0', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '1', 'conv1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '1', 'conv1', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'mid_block', 'resnets', '1', 'conv2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '1', 'conv2', 'kernel', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 512, 512]}}, "('encoder', 'mid_block', 'resnets', '1', 'norm1', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '1', 'norm1', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '1', 'norm2', 'bias', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('encoder', 'mid_block', 'resnets', '1', 'norm2', 'scale', 'value')": {"key_metadata": [{"key": "encoder", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}, {"key": "value", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}}, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} |