{"tree_metadata": {"('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'edge_emb', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "edge_emb", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'head', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "head", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'head', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "head", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'lambda_emb', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'lambda_emb', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'lambda_emb', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'lambda_emb', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_0', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_1', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_2', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_3', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_4', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_5', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_6', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'layer_7', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'pos_emb', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "pos_emb", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'time_emb', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'time_emb', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'time_emb', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'time_emb', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'y_init', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "y_init", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'mu', 'params', 'y_init', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "y_init", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'edge_emb', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "edge_emb", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'head', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "head", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'head', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "head", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'lambda_emb', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'lambda_emb', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'lambda_emb', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'lambda_emb', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_0', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_1', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_2', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_3', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_4', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_5', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_6', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'layer_7', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'pos_emb', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "pos_emb", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'time_emb', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'time_emb', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'time_emb', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'time_emb', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'y_init', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "y_init", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '0', 'nu', 'params', 'y_init', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "y_init", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('opt_state', '1', '1')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '2')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "2", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('params', 'params', 'edge_emb', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "edge_emb", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'head', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "head", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'head', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "head", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'lambda_emb', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'lambda_emb', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'lambda_emb', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'lambda_emb', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "lambda_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_0', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_0", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_1', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_1", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_2', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_2", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_3', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_3", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_4', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_4", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_5', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_5", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_6', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_6", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'e_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "e_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'k', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'k', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "k", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'q', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'q', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "q", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'v', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'v', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "v", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'x_out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'x_out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'x_to_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'x_to_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "x_to_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_e_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_e_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_e_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_e_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_e_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_out_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_out_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_out_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_out_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_out_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_x_add', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_x_add', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_add", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_x_mul', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_x_mul', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_x_mul", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_y', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'attn', 'y_y', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "y_y", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_e_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_e_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_e_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_x_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_x_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_x_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_y_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'ffn_y_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "ffn_y_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_l_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_l_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_l_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_l_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_l_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_t_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_t_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_t_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_t_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_t_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_y_e', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_y_e', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_e", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_y_x', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'film_y_x', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "film_y_x", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_e_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_e_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_e_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_e_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_e_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_x_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_x_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_x_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_x_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_x_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_y_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_y_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_y_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'layer_7', 'norm_y_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "layer_7", "key_type": 2}, {"key": "norm_y_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'pos_emb', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "pos_emb", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'time_emb', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'time_emb', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'time_emb', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'time_emb', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "time_emb", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'y_init', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "y_init", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'params', 'y_init', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "y_init", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('train_key',)": {"key_metadata": [{"key": "train_key", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}}, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}