| {"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'transformer', 'h', '0', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('params', 'transformer', 'h', '0', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('params', 'transformer', 'h', '0', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '0', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'transformer', 'h', '0', 'ln_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '0', 'ln_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '0', 'ln_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '0', 'ln_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '0', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('params', 'transformer', 'h', '0', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('params', 'transformer', 'h', '0', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '0', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('params', 'transformer', 'h', '1', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('params', 'transformer', 'h', '1', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('params', 'transformer', 'h', '1', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '1', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'transformer', 'h', '1', 'ln_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '1', 'ln_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '1', 'ln_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '1', 'ln_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '1', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('params', 'transformer', 'h', '1', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('params', 'transformer', 'h', '1', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '1', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('params', 'transformer', 'h', '2', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('params', 'transformer', 'h', '2', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('params', 'transformer', 'h', '2', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '2', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'transformer', 'h', '2', 'ln_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '2', 'ln_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '2', 'ln_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '2', 'ln_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '2', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('params', 'transformer', 'h', '2', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('params', 'transformer', 'h', '2', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '2', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('params', 'transformer', 'h', '3', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('params', 'transformer', 'h', '3', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('params', 'transformer', 'h', '3', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '3', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'transformer', 'h', '3', 'ln_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '3', 'ln_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '3', 'ln_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '3', 'ln_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '3', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('params', 'transformer', 'h', '3', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('params', 'transformer', 'h', '3', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '3', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('params', 'transformer', 'h', '4', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('params', 'transformer', 'h', '4', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('params', 'transformer', 'h', '4', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '4', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'transformer', 'h', '4', 'ln_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '4', 'ln_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '4', 'ln_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '4', 'ln_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '4', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('params', 'transformer', 'h', '4', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('params', 'transformer', 'h', '4', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '4', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('params', 'transformer', 'h', '5', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('params', 'transformer', 'h', '5', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('params', 'transformer', 'h', '5', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '5', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'transformer', 'h', '5', 'ln_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '5', 'ln_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '5', 'ln_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '5', 'ln_2', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '5', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('params', 'transformer', 'h', '5', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('params', 'transformer', 'h', '5', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'h', '5', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('params', 'transformer', 'ln_f', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "ln_f", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'ln_f', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "ln_f", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'transformer', 'wpe', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "wpe", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [30, 512]}}, "('params', 'transformer', 'wte', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "wte", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [19, 512]}}, "('opt_state', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '0', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '1', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '2', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '3', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '4', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'h', '5', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'mu', 'transformer', 'ln_f', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "ln_f", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'ln_f', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "ln_f", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'mu', 'transformer', 'wpe', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "wpe", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [30, 512]}}, "('opt_state', '0', 'mu', 'transformer', 'wte', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "wte", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [19, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '0', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "0", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '1', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "1", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '2', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "2", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '3', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "3", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '4', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "4", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'attn', 'c_attn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'attn', 'c_attn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_attn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1536, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'attn', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'attn', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "attn", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'ln_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'ln_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'ln_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'ln_2', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "ln_2", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'mlp', 'c_fc', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'mlp', 'c_fc', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_fc", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2048, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'mlp', 'c_proj', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'h', '5', 'mlp', 'c_proj', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "h", "key_type": 2}, {"key": "5", "key_type": 2}, {"key": "mlp", "key_type": 2}, {"key": "c_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 2048]}}, "('opt_state', '0', 'nu', 'transformer', 'ln_f', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "ln_f", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'ln_f', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "ln_f", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '0', 'nu', 'transformer', 'wpe', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "wpe", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [30, 512]}}, "('opt_state', '0', 'nu', 'transformer', 'wte', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "transformer", "key_type": 2}, {"key": "wte", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [19, 512]}}, "('opt_state', '1')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '2', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "2", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('dropout_rng',)": {"key_metadata": [{"key": "dropout_rng", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [4, 2]}}}, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} |