diff --git "a/default/_METADATA" "b/default/_METADATA" new file mode 100644--- /dev/null +++ "b/default/_METADATA" @@ -0,0 +1 @@ +{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'acc', 'halt_net', 'layers_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [24]}}, "('params', 'acc', 'halt_net', 'layers_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 192]}}, "('params', 'acc', 'halt_net', 'layers_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'acc', 'halt_net', 'layers_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [24, 1]}}, "('params', 'acc', 'loop_embed', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "loop_embed", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [4, 768]}}, "('params', 'acc', 'state_gate', 'layers_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_gate", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'acc', 'state_gate', 'layers_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_gate", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'acc', 'state_norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'acc', 'state_norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'acc', 'state_transform', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_transform", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'acc', 'state_transform', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_transform", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'embedding', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "embedding", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [50257, 96]}}, "('params', 'controller', 'final_norm', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "final_norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'final_norm', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "final_norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_0', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_0', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_1', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_1', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_10', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_10', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_10', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_10', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_10', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_10', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_11', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_11', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_11', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_11', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_11', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_11', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_2', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_2', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_3', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_3', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_4', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_4', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_5', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_5', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_6', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_6', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_6', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_6', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_6', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_6', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_7', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_7', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_7', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_7', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_7', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_7', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_8', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_8', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_8', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_8', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_8', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_8', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'layers_9', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('params', 'controller', 'layers_9', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'controller', 'layers_9', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_9', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_9', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_9', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('params', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('params', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'controller', 'lm_head', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "lm_head", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 50257]}}, "('params', 'controller', 'pos_encoding', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "pos_encoding", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128, 768]}}, "('params', 'indexer', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1]}}, "('params', 'indexer', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'indexer', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'indexer', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48]}}, "('params', 'indexer', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 384]}}, "('params', 'indexer', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'indexer', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48, 1]}}, "('params', 'indexer', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'indexer', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48, 1]}}, "('params', 'pool', 'params_storage')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "pool", "key_type": 2}, {"key": "params_storage", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [32768, 768]}}, "('params', 'retrieval_integrator', 'layers_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'retrieval_integrator', 'layers_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('params', 'retrieval_integrator', 'layers_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'retrieval_integrator', 'layers_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('params', 'retrieval_integrator', 'layers_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('params', 'retrieval_integrator', 'layers_3', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'acc', 'halt_net', 'layers_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [24]}}, "('opt_state', '1', '0', 'mu', 'acc', 'halt_net', 'layers_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 192]}}, "('opt_state', '1', '0', 'mu', 'acc', 'halt_net', 'layers_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'acc', 'halt_net', 'layers_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [24, 1]}}, "('opt_state', '1', '0', 'mu', 'acc', 'loop_embed', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "loop_embed", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [4, 768]}}, "('opt_state', '1', '0', 'mu', 'acc', 'state_gate', 'layers_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_gate", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'acc', 'state_gate', 'layers_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_gate", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'acc', 'state_norm', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'acc', 'state_norm', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'acc', 'state_transform', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_transform", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'acc', 'state_transform', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_transform", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'embedding', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "embedding", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [50257, 96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'final_norm', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "final_norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'final_norm', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "final_norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'controller', 'lm_head', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "lm_head", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 50257]}}, "('opt_state', '1', '0', 'mu', 'controller', 'pos_encoding', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "pos_encoding", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128, 768]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 384]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48, 1]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'indexer', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48, 1]}}, "('opt_state', '1', '0', 'mu', 'retrieval_integrator', 'layers_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'retrieval_integrator', 'layers_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'mu', 'retrieval_integrator', 'layers_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'retrieval_integrator', 'layers_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'mu', 'retrieval_integrator', 'layers_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'mu', 'retrieval_integrator', 'layers_3', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'acc', 'halt_net', 'layers_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [24]}}, "('opt_state', '1', '0', 'nu', 'acc', 'halt_net', 'layers_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 192]}}, "('opt_state', '1', '0', 'nu', 'acc', 'halt_net', 'layers_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'acc', 'halt_net', 'layers_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "halt_net", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [24, 1]}}, "('opt_state', '1', '0', 'nu', 'acc', 'loop_embed', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "loop_embed", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [4, 768]}}, "('opt_state', '1', '0', 'nu', 'acc', 'state_gate', 'layers_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_gate", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'acc', 'state_gate', 'layers_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_gate", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'acc', 'state_norm', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'acc', 'state_norm', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'acc', 'state_transform', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_transform", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'acc', 'state_transform', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "acc", "key_type": 2}, {"key": "state_transform", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'embedding', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "embedding", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [50257, 96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'final_norm', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "final_norm", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'final_norm', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "final_norm", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_0', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_1', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_1", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_10', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_10", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_11', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_11", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_2', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_3', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_4', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_4", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_5', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_5", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_6', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_6", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_7', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_7", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_8', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_8", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'FlashCausalSelfAttention_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 2304]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'FlashCausalSelfAttention_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "FlashCausalSelfAttention_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1536]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'controller', 'layers_9', 'TinyFFN_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "layers_9", "key_type": 2}, {"key": "TinyFFN_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'controller', 'lm_head', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "lm_head", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 50257]}}, "('opt_state', '1', '0', 'nu', 'controller', 'pos_encoding', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "controller", "key_type": 2}, {"key": "pos_encoding", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [128, 768]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 1]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 384]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48, 1]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'indexer', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "indexer", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [48, 1]}}, "('opt_state', '1', '0', 'nu', 'retrieval_integrator', 'layers_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'retrieval_integrator', 'layers_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [192, 768]}}, "('opt_state', '1', '0', 'nu', 'retrieval_integrator', 'layers_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'retrieval_integrator', 'layers_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96, 768]}}, "('opt_state', '1', '0', 'nu', 'retrieval_integrator', 'layers_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '0', 'nu', 'retrieval_integrator', 'layers_3', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "retrieval_integrator", "key_type": 2}, {"key": "layers_3", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [96]}}, "('opt_state', '1', '1')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '2', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "2", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('rng',)": {"key_metadata": [{"key": "rng", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [2]}}, "('pool_m',)": {"key_metadata": [{"key": "pool_m", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [32768, 768]}}, "('pool_v',)": {"key_metadata": [{"key": "pool_v", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [32768, 768]}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} \ No newline at end of file