| {"tree_metadata": {"('params', 'conv_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "conv_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'conv_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "conv_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 4, 320]}}, "('params', 'conv_norm_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "conv_norm_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'conv_norm_out', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "conv_norm_out", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'conv_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "conv_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [4]}}, "('params', 'conv_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "conv_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 4]}}, "('params', 'down_blocks_0', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 2560]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 2560]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'downsamplers_0', 'conv', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "downsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'downsamplers_0', 'conv', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "downsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'down_blocks_0', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'down_blocks_0', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'down_blocks_0', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'down_blocks_0', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'down_blocks_0', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'down_blocks_0', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'down_blocks_1', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 5120]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560, 640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 5120]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560, 640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'downsamplers_0', 'conv', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "downsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'downsamplers_0', 'conv', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "downsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'down_blocks_1', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 640]}}, "('params', 'down_blocks_1', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 640]}}, "('params', 'down_blocks_1', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_1', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'down_blocks_1', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 640]}}, "('params', 'down_blocks_1', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'down_blocks_1', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'down_blocks_1', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 640]}}, "('params', 'down_blocks_2', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [10240]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 10240]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120, 1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [10240]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 10240]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120, 1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'downsamplers_0', 'conv', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "downsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'downsamplers_0', 'conv', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "downsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_2', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'down_blocks_2', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'down_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "down_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [10240]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 10240]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120, 1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'mid_block', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'mid_block', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'mid_block', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'mid_block', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'mid_block', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'mid_block', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "mid_block", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'time_embedding', 'linear_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "time_embedding", "key_type": 2}, {"key": "linear_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'time_embedding', 'linear_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "time_embedding", "key_type": 2}, {"key": "linear_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 1280]}}, "('params', 'time_embedding', 'linear_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "time_embedding", "key_type": 2}, {"key": "linear_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'time_embedding', 'linear_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "time_embedding", "key_type": 2}, {"key": "linear_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 2560, 1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 2560, 1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_0', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_0', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 2560, 1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 2560, 1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_0', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_0', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 2560, 1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 2560, 1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_0', 'resnets_2', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_0', 'resnets_2', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'resnets_2', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_0', 'upsamplers_0', 'conv', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "upsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_0', 'upsamplers_0', 'conv', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_0", "key_type": 2}, {"key": "upsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [10240]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 10240]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120, 1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [10240]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 10240]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120, 1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [10240]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 10240]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120, 1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 2560, 1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 2560, 1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_1', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_1', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 2560, 1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 2560, 1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_1', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_1', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1920, 1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1920, 1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1920]}}, "('params', 'up_blocks_1', 'resnets_2', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1920]}}, "('params', 'up_blocks_1', 'resnets_2', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'resnets_2', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 1280]}}, "('params', 'up_blocks_1', 'upsamplers_0', 'conv', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "upsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_1', 'upsamplers_0', 'conv', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_1", "key_type": 2}, {"key": "upsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 1280]}}, "('params', 'up_blocks_2', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 5120]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560, 640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 5120]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560, 640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [5120]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640, 5120]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560, 640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1920, 640]}}, "('params', 'up_blocks_2', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1920, 640]}}, "('params', 'up_blocks_2', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1920]}}, "('params', 'up_blocks_2', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1920]}}, "('params', 'up_blocks_2', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 640]}}, "('params', 'up_blocks_2', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 1280, 640]}}, "('params', 'up_blocks_2', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'up_blocks_2', 'resnets_1', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_1', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 1280, 640]}}, "('params', 'up_blocks_2', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_2', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280]}}, "('params', 'up_blocks_2', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 640]}}, "('params', 'up_blocks_2', 'resnets_2', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_2', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 960, 640]}}, "('params', 'up_blocks_2', 'resnets_2', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_2', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'up_blocks_2', 'resnets_2', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_2', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 960, 640]}}, "('params', 'up_blocks_2', 'resnets_2', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [960]}}, "('params', 'up_blocks_2', 'resnets_2', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [960]}}, "('params', 'up_blocks_2', 'resnets_2', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_2', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_2', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'resnets_2', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 640]}}, "('params', 'up_blocks_2', 'upsamplers_0', 'conv', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "upsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_2', 'upsamplers_0', 'conv', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_2", "key_type": 2}, {"key": "upsamplers_0", "key_type": 2}, {"key": "conv", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 640]}}, "('params', 'up_blocks_3', 'attentions_0', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 2560]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_0", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 2560]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_1", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'proj_in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'proj_in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'proj_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'proj_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "proj_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn1", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_out_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "attn2", "key_type": 2}, {"key": "to_v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2560]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_0", "key_type": 2}, {"key": "proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320, 2560]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "ff", "key_type": 2}, {"key": "net_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "attentions_2", "key_type": 2}, {"key": "transformer_blocks_0", "key_type": 2}, {"key": "norm3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 960, 320]}}, "('params', 'up_blocks_3', 'resnets_0', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 960, 320]}}, "('params', 'up_blocks_3', 'resnets_0', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [960]}}, "('params', 'up_blocks_3', 'resnets_0', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [960]}}, "('params', 'up_blocks_3', 'resnets_0', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_0", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'up_blocks_3', 'resnets_1', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_1', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 320]}}, "('params', 'up_blocks_3', 'resnets_1', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_1', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'up_blocks_3', 'resnets_1', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_1', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 320]}}, "('params', 'up_blocks_3', 'resnets_1', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_3', 'resnets_1', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_3', 'resnets_1', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_1', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_1', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_1", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}, "('params', 'up_blocks_3', 'resnets_2', 'conv1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_2', 'conv1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 640, 320]}}, "('params', 'up_blocks_3', 'resnets_2', 'conv2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_2', 'conv2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [3, 3, 320, 320]}}, "('params', 'up_blocks_3', 'resnets_2', 'conv_shortcut', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_2', 'conv_shortcut', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "conv_shortcut", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1, 1, 640, 320]}}, "('params', 'up_blocks_3', 'resnets_2', 'norm1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_3', 'resnets_2', 'norm1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [640]}}, "('params', 'up_blocks_3', 'resnets_2', 'norm2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_2', 'norm2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "norm2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_2', 'time_emb_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [320]}}, "('params', 'up_blocks_3', 'resnets_2', 'time_emb_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "up_blocks_3", "key_type": 2}, {"key": "resnets_2", "key_type": 2}, {"key": "time_emb_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1280, 320]}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} |