diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..ca6b382d9dbc68686689c5fda8a58348654ed43a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,121 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 filter=lfs diff=lfs merge=lfs -text
+checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf filter=lfs diff=lfs merge=lfs -text
+checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a filter=lfs diff=lfs merge=lfs -text
+checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 filter=lfs diff=lfs merge=lfs -text
+checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 filter=lfs diff=lfs merge=lfs -text
+checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png filter=lfs diff=lfs merge=lfs -text
+experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png filter=lfs diff=lfs merge=lfs -text
diff --git a/Craftax_Baselines/.gitignore b/Craftax_Baselines/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..dfa0dc993a3761abbcb3138721b366740183e49c
--- /dev/null
+++ b/Craftax_Baselines/.gitignore
@@ -0,0 +1,169 @@
+tmp/
+wandb/
+res/
+runs/
+
+play_data
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+texture_cache.pbz2
+texture_cache*.pbz2
diff --git a/Craftax_Baselines/.pre-commit-config.yaml b/Craftax_Baselines/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ad94cafd5029ff0caaa470d5a76143bd92d6150e
--- /dev/null
+++ b/Craftax_Baselines/.pre-commit-config.yaml
@@ -0,0 +1,6 @@
+repos:
+- repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ language_version: python3
\ No newline at end of file
diff --git a/Craftax_Baselines/Dockerfile b/Craftax_Baselines/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..55ef0a0d7afec9bfafdc33f8510d7e2ff69385a5
--- /dev/null
+++ b/Craftax_Baselines/Dockerfile
@@ -0,0 +1,41 @@
+FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
+
+ENV CUDA_PATH /usr/local/cuda
+ENV CUDA_INCLUDE_PATH /usr/local/cuda/include
+ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64
+
+# Set timezone
+ENV TZ=Europe/London DEBIAN_FRONTEND=noninteractive
+
+# Add Python 3.8 to Ubuntu 22.04 and install dependencies
+RUN apt update
+RUN apt install -y software-properties-common && add-apt-repository ppa:deadsnakes/ppa
+RUN apt install -y \
+ git \
+ python3.8 \
+ python3-pip \
+ python3.8-venv \
+ python3-setuptools \
+ python3-wheel
+
+# Create local user
+# https://jtreminio.com/blog/running-docker-containers-as-current-host-user/
+ARG UID
+ARG GID
+RUN if [ ${UID:-0} -ne 0 ] && [ ${GID:-0} -ne 0 ]; then \
+ groupadd -g ${GID} duser &&\
+ useradd -l -u ${UID} -g duser duser &&\
+ install -d -m 0755 -o duser -g duser /home/duser &&\
+ chown --changes --silent --no-dereference --recursive ${UID}:${GID} /home/duser \
+ ;fi
+
+USER duser
+WORKDIR /home/duser
+
+# Install Python packages
+ENV PATH="/home/duser/.local/bin:$PATH"
+RUN python3 -m pip install --upgrade pip
+ARG REQS
+RUN pip install $REQS -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+WORKDIR /home/duser/Craftax
diff --git a/Craftax_Baselines/LICENSE b/Craftax_Baselines/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..70b0f5878e5bd2bc15aec41eb21b98f5a4571ebd
--- /dev/null
+++ b/Craftax_Baselines/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2024 Michael Matthews
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/Craftax_Baselines/README.md b/Craftax_Baselines/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f521ba316623e508c55bb33531db221e7f1a64b
--- /dev/null
+++ b/Craftax_Baselines/README.md
@@ -0,0 +1,46 @@
+
+
+
+
+# Craftax Baselines
+
+This repository contains the code for running the baselines from the [Craftax paper](https://arxiv.org/abs/2402.16801).
+For packaging reasons, this is separate to the [main repository](https://github.com/MichaelTMatthews/Craftax/).
+
+# Installation
+```commandline
+git clone https://github.com/MichaelTMatthews/Craftax_Baselines.git
+cd Craftax_Baselines
+pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pre-commit install
+```
+
+# Run Experiments
+
+### PPO
+```commandline
+python ppo.py
+```
+
+### PPO-RNN
+```commandline
+python ppo_rnn.py
+```
+
+### ICM
+```commandline
+python ppo.py --train_icm
+```
+
+### E3B
+```commandline
+python ppo.py --train_icm --use_e3b --icm_reward_coeff 0
+```
+
+### RND
+```commandline
+python ppo_rnd.py
+```
+
+# Visualisation
+You can save trained policies with the `--save_policy` flag. These can then be viewed with the `view_ppo_agent` script (pass in the path up to the `files` directory).
\ No newline at end of file
diff --git a/Craftax_Baselines/analysis/__init__.py b/Craftax_Baselines/analysis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Craftax_Baselines/analysis/view_ppo_agent.py b/Craftax_Baselines/analysis/view_ppo_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..c48da547f9b92025515c4c8d96e29881237e9e23
--- /dev/null
+++ b/Craftax_Baselines/analysis/view_ppo_agent.py
@@ -0,0 +1,151 @@
+import argparse
+import os
+import sys
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import yaml
+from craftax.environment_base.wrappers import AutoResetEnvWrapper
+from flax.training.train_state import TrainState
+import orbax.checkpoint as ocp
+
+from ..models.actor_critic import ActorCriticConv, ActorCritic
+
+
+def main(args):
+
+ with open(os.path.join(args.path, "config.yaml")) as f:
+ raw_config = yaml.load(f, Loader=yaml.Loader)
+
+ config = {}
+ for key, value in raw_config.items():
+ if isinstance(value, dict) and "value" in value:
+ config[key] = value["value"]
+
+ config["NUM_ENVS"] = 1
+
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
+ checkpoint_manager = ocp.CheckpointManager(
+ os.path.join(args.path, "policies"),
+ options=options
+ )
+
+ is_classic = False
+
+ if config["ENV_NAME"] == "Craftax-Symbolic-v1":
+ from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
+ from craftax.craftax.constants import Action
+
+ env = CraftaxSymbolicEnv(CraftaxSymbolicEnv.default_static_params())
+ network = ActorCritic(len(Action), config["LAYER_SIZE"])
+ elif config["ENV_NAME"] == "Craftax-Pixels-v1":
+ from craftax.craftax.envs.craftax_pixels_env import CraftaxPixelsEnv
+ from craftax.craftax.constants import Action
+
+ env = CraftaxPixelsEnv(CraftaxPixelsEnv.default_static_params())
+ network = ActorCriticConv(len(Action), config["LAYER_SIZE"])
+ elif config["ENV_NAME"] == "Craftax-Classic-Symbolic-v1":
+ from craftax.craftax_classic.envs.craftax_symbolic_env import (
+ CraftaxClassicSymbolicEnv,
+ )
+ from craftax.craftax_classic.constants import Action
+
+ env = CraftaxClassicSymbolicEnv(
+ CraftaxClassicSymbolicEnv.default_static_params()
+ )
+ network = ActorCritic(len(Action), config["LAYER_SIZE"])
+ is_classic = True
+ elif config["ENV_NAME"] == "Craftax-Classic-Pixels-v1":
+ from craftax.craftax_classic.envs.craftax_pixels_env import (
+ CraftaxClassicPixelsEnv,
+ )
+ from craftax.craftax_classic.constants import Action
+
+ env = CraftaxClassicPixelsEnv(CraftaxClassicPixelsEnv.default_static_params())
+ network = ActorCriticConv(len(Action), config["LAYER_SIZE"])
+ is_classic = True
+ else:
+ raise ValueError(f"Unknown env: {config['ENV_NAME']}")
+
+ env = AutoResetEnvWrapper(env)
+ env_params = env.default_params
+
+ init_x = jnp.zeros((config["NUM_ENVS"], *env.observation_space(env_params).shape))
+
+ rng = jax.random.PRNGKey(np.random.randint(2**31))
+ rng, _rng, __rng = jax.random.split(rng, 3)
+ network_params = network.init(_rng, init_x)
+
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["LR"], eps=1e-5),
+ )
+ train_state = TrainState.create(
+ apply_fn=network.apply,
+ params=network_params,
+ tx=tx,
+ )
+
+ abstract_train_state = jax.eval_shape(lambda: train_state)
+
+ train_state = checkpoint_manager.restore(
+ config["TOTAL_TIMESTEPS"],
+ args=ocp.args.StandardRestore(abstract_train_state)
+ )
+
+ obs, env_state = env.reset(key=__rng)
+ done = 0
+
+ if is_classic:
+ from craftax.craftax_classic.play_craftax_classic import CraftaxRenderer
+ from craftax.craftax_classic.constants import Achievement
+ else:
+ from craftax.craftax.play_craftax import CraftaxRenderer
+ from craftax.craftax.constants import Achievement
+
+ renderer = CraftaxRenderer(env, env_params, pixel_render_size=1)
+
+ while not renderer.is_quit_requested():
+ done = np.array([done], dtype=bool)
+ obs = jnp.expand_dims(obs, axis=0)
+
+ pi, value = network.apply(train_state.params, obs)
+ rng, _rng = jax.random.split(rng)
+ action = pi.sample(seed=_rng)[0]
+ # action = jnp.argmax(pi.probs[0, 0])
+
+ if action is not None:
+ rng, _rng = jax.random.split(rng)
+ old_achievements = env_state.achievements
+ obs, env_state, reward, done, info = env.step(
+ _rng, env_state, action, env_params
+ )
+ new_achievements = env_state.achievements
+ print_new_achievements(Achievement, old_achievements, new_achievements)
+ if done:
+ print("\n")
+ renderer.render(env_state)
+
+
+def print_new_achievements(achievements_cls, old_achievements, new_achievements):
+ for i in range(len(old_achievements)):
+ if old_achievements[i] == 0 and new_achievements[i] == 1:
+ print(f"{achievements_cls(i).name} ({new_achievements.sum()}/{22})")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--path", type=str)
+ parser.add_argument("--debug", action="store_true")
+
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
+ if rest_args:
+ raise ValueError(f"Unknown args {rest_args}")
+
+ if args.debug:
+ with jax.disable_jit():
+ main(args)
+ else:
+ main(args)
diff --git a/Craftax_Baselines/build.sh b/Craftax_Baselines/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6ecce6ac49188a51e9e31ebbc632bec9622a4
--- /dev/null
+++ b/Craftax_Baselines/build.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+echo 'Building Dockerfile with image name craftax'
+docker build \
+ --build-arg UID=$(id -u ${USER}) \
+ --build-arg GID=1234 \
+ --build-arg REQS="$(cat requirements.txt)" \
+ -t craftax_baselines \
+ --no-cache \
+ .
diff --git a/Craftax_Baselines/images/logo.png b/Craftax_Baselines/images/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..baebfeda0d5e6ce832bc0710afd8c1ab14859434
Binary files /dev/null and b/Craftax_Baselines/images/logo.png differ
diff --git a/Craftax_Baselines/logz/__init__.py b/Craftax_Baselines/logz/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Craftax_Baselines/logz/batch_logging.py b/Craftax_Baselines/logz/batch_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c5c397478773a0f085abde3a72cc60cf5d217be
--- /dev/null
+++ b/Craftax_Baselines/logz/batch_logging.py
@@ -0,0 +1,115 @@
+import time
+
+import jax.numpy as jnp
+import numpy as np
+import wandb
+
+batch_logs = {}
+log_times = []
+
+
+def create_log_dict(info, config):
+ to_log = {
+ "episode_return": info["returned_episode_returns"],
+ "episode_length": info["returned_episode_lengths"],
+ }
+
+ diffusion_keys = [
+ "loss", "unweighted_loss", "accuracy", "mean_t",
+ "acc_t_low", "acc_t_mid", "acc_t_high", "grad_norm",
+ "action_entropy", "action_unique_frac"
+ ]
+ for k in diffusion_keys:
+ if k in info:
+ to_log[f"diffusion/{k}"] = info[k]
+
+ sum_achievements = 0.0
+ sum_val_achievements = 0.0
+ has_val = False
+
+ for k, v in info.items():
+ if k.startswith("val/"):
+ has_val = True
+ to_log[k] = v
+ if "achievements" in k.lower() and k != "val/achievements":
+ sum_val_achievements += v / 100.0
+ elif "achievements" in k.lower():
+ to_log[k] = v
+ if k != "achievements":
+ sum_achievements += v / 100.0
+
+ to_log["achievements"] = sum_achievements
+ if has_val:
+ to_log["val/achievements"] = sum_val_achievements
+
+ if config.get("TRAIN_ICM") or config.get("USE_RND"):
+ to_log["intrinsic_reward"] = info.get("reward_i", 0.0)
+ to_log["extrinsic_reward"] = info.get("reward_e", 0.0)
+
+ if config.get("TRAIN_ICM"):
+ to_log["icm_inverse_loss"] = info.get("icm_inverse_loss", 0.0)
+ to_log["icm_forward_loss"] = info.get("icm_forward_loss", 0.0)
+ elif config.get("USE_RND"):
+ to_log["rnd_loss"] = info.get("rnd_loss", 0.0)
+
+ return to_log
+
+
+def batch_log(update_step, log, config):
+ update_step = int(update_step)
+ if update_step not in batch_logs:
+ batch_logs[update_step] = []
+
+ batch_logs[update_step].append(log)
+
+ if len(batch_logs[update_step]) == config.get("NUM_REPEATS", 1):
+ agg_logs = {}
+ for key in batch_logs[update_step][0]:
+ agg = []
+ if key in ["goal_heatmap"]:
+ agg = [batch_logs[update_step][0][key]]
+ else:
+ for i in range(config.get("NUM_REPEATS", 1)):
+ # Use .get() to prevent KeyErrors if repeats are out of sync
+ val = batch_logs[update_step][i].get(key, float("nan"))
+ if not jnp.isnan(val):
+ agg.append(val)
+
+ if len(agg) > 0:
+ if key in [
+ "episode_length",
+ "episode_return",
+ "exploration_bonus",
+ "e_mean",
+ "e_std",
+ "rnd_loss",
+ "diffusion/loss",
+ "diffusion/unweighted_loss",
+ "diffusion/accuracy",
+ "diffusion/acc_t_low",
+ "diffusion/acc_t_mid",
+ "diffusion/acc_t_high",
+ "diffusion/action_entropy",
+ "diffusion/grad_norm"
+ ] or key.startswith("val/") or "achievement" in key.lower():
+ agg_logs[key] = np.mean(agg)
+ else:
+ agg_logs[key] = np.array(agg)
+
+ log_times.append(time.time())
+
+ if config.get("DEBUG"):
+ if len(log_times) == 1:
+ print("Started logging")
+ elif len(log_times) > 1:
+ dt = log_times[-1] - log_times[-2]
+ steps_between_updates = (
+ config["NUM_STEPS"] * config["NUM_ENVS"] * config.get("NUM_REPEATS", 1)
+ )
+ sps = steps_between_updates / dt
+ agg_logs["sps"] = sps
+
+ wandb.log(agg_logs)
+
+ # Clear buffer to prevent memory leaks
+ del batch_logs[update_step]
\ No newline at end of file
diff --git a/Craftax_Baselines/models/__init__.py b/Craftax_Baselines/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Craftax_Baselines/models/actor_critic.py b/Craftax_Baselines/models/actor_critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bd3b7a31e81caf0cf6d17f225c8828f9cbe8a10
--- /dev/null
+++ b/Craftax_Baselines/models/actor_critic.py
@@ -0,0 +1,256 @@
+import jax.numpy as jnp
+import flax.linen as nn
+import numpy as np
+from flax.linen.initializers import constant, orthogonal
+from typing import Sequence
+
+import distrax
+
+
+class ActorCriticConvSymbolicCraftax(nn.Module):
+ action_dim: int
+ map_obs_shape: Sequence[int]
+ layer_width: int
+
+ @nn.compact
+ def __call__(self, obs):
+ # Split into map and flat obs
+ flat_map_obs_shape = (
+ self.map_obs_shape[0] * self.map_obs_shape[1] * self.map_obs_shape[2]
+ )
+ image_obs = obs[:, :flat_map_obs_shape]
+ image_dim = self.map_obs_shape
+ image_obs = image_obs.reshape((image_obs.shape[0], *image_dim))
+
+ flat_obs = obs[:, flat_map_obs_shape:]
+
+ # Convolutions on map
+ image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_obs)
+ image_embedding = nn.relu(image_embedding)
+ image_embedding = nn.max_pool(
+ image_embedding, window_shape=(2, 2), strides=(1, 1)
+ )
+ image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_embedding)
+ image_embedding = nn.relu(image_embedding)
+ image_embedding = nn.max_pool(
+ image_embedding, window_shape=(2, 2), strides=(1, 1)
+ )
+ image_embedding = image_embedding.reshape(image_embedding.shape[0], -1)
+ # image_embedding = jnp.concatenate([image_embedding, obs[:, : CraftaxEnv.get_flat_map_obs_shape()]], axis=-1)
+
+ # Combine embeddings
+ embedding = jnp.concatenate([image_embedding, flat_obs], axis=-1)
+ embedding = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(embedding)
+ embedding = nn.relu(embedding)
+
+ actor_mean = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(embedding)
+ actor_mean = nn.relu(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+ actor_mean = nn.relu(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+
+ pi = distrax.Categorical(logits=actor_mean)
+
+ critic = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(embedding)
+ critic = nn.relu(critic)
+ critic = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(critic)
+ critic = nn.relu(critic)
+ critic = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(critic)
+ critic = nn.relu(critic)
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic
+ )
+
+ return pi, jnp.squeeze(critic, axis=-1)
+
+
+class ActorCriticConv(nn.Module):
+ action_dim: int
+ layer_width: int
+ activation: str = "tanh"
+
+ @nn.compact
+ def __call__(self, obs):
+ x = nn.Conv(features=32, kernel_size=(5, 5))(obs)
+ x = nn.relu(x)
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
+ x = nn.Conv(features=32, kernel_size=(5, 5))(x)
+ x = nn.relu(x)
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
+ x = nn.Conv(features=32, kernel_size=(5, 5))(x)
+ x = nn.relu(x)
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
+
+ embedding = x.reshape(x.shape[0], -1)
+
+ actor_mean = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(embedding)
+ actor_mean = nn.relu(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+ actor_mean = nn.relu(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+
+ pi = distrax.Categorical(logits=actor_mean)
+
+ critic = nn.Dense(
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
+ )(embedding)
+ critic = nn.relu(critic)
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic
+ )
+
+ return pi, jnp.squeeze(critic, axis=-1)
+
+
+class ActorCritic(nn.Module):
+ action_dim: int
+ layer_width: int
+ activation: str = "tanh"
+
+ @nn.compact
+ def __call__(self, x):
+ if self.activation == "relu":
+ activation = nn.relu
+ else:
+ activation = nn.tanh
+
+ actor_mean = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ actor_mean = activation(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(actor_mean)
+ actor_mean = activation(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(actor_mean)
+ actor_mean = activation(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+ pi = distrax.Categorical(logits=actor_mean)
+
+ critic = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ critic = activation(critic)
+
+ critic = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic)
+ critic = activation(critic)
+
+ critic = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic)
+ critic = activation(critic)
+
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic
+ )
+
+ return pi, jnp.squeeze(critic, axis=-1)
+
+
+class ActorCriticWithEmbedding(nn.Module):
+ action_dim: int
+ layer_width: int
+ activation: str = "tanh"
+
+ @nn.compact
+ def __call__(self, x):
+ if self.activation == "relu":
+ activation = nn.relu
+ else:
+ activation = nn.tanh
+
+ actor_emb = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ actor_emb = activation(actor_emb)
+
+ actor_emb = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(actor_emb)
+ actor_emb = activation(actor_emb)
+
+ actor_emb = nn.Dense(
+ 128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+ )(actor_emb)
+ actor_emb = activation(actor_emb)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_emb)
+ pi = distrax.Categorical(logits=actor_mean)
+
+ critic = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ critic = activation(critic)
+
+ critic = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic)
+ critic = activation(critic)
+
+ critic = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic)
+ critic = activation(critic)
+
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic
+ )
+
+ return pi, jnp.squeeze(critic, axis=-1), actor_emb
diff --git a/Craftax_Baselines/models/icm.py b/Craftax_Baselines/models/icm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d9154f417276c61c4f94122a94001aa7eecaec3
--- /dev/null
+++ b/Craftax_Baselines/models/icm.py
@@ -0,0 +1,72 @@
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+
+
+class ICMEncoder(nn.Module):
+ layer_size: int
+ output_dim: int
+ num_layers: int
+
+ @nn.compact
+ def __call__(self, obs):
+ activation = nn.relu
+
+ # TODO Look at weight inits
+
+ emb = obs
+ for _ in range(self.num_layers):
+ emb = nn.Dense(
+ self.layer_size,
+ )(emb)
+ emb = activation(emb)
+
+ emb = nn.Dense(self.output_dim)(emb)
+
+ return emb
+
+
+class ICMForward(nn.Module):
+ layer_size: int
+ output_dim: int
+ num_layers: int
+ num_actions: int
+
+ @nn.compact
+ def __call__(self, latent, action):
+ activation = nn.relu
+
+ action1h = jax.nn.one_hot(action, num_classes=self.num_actions)
+ emb = jnp.concatenate((latent, action1h), axis=-1)
+ for _ in range(self.num_layers):
+ emb = nn.Dense(
+ self.layer_size,
+ )(emb)
+ emb = activation(emb)
+
+ emb = nn.Dense(self.output_dim)(emb)
+
+ return emb
+
+
+class ICMInverse(nn.Module):
+ layer_size: int
+ output_dim: int
+ num_layers: int
+
+ @nn.compact
+ def __call__(self, latent, next_latent):
+ activation = nn.relu
+
+ emb = jnp.concatenate((latent, next_latent), axis=-1)
+ for _ in range(self.num_layers):
+ emb = nn.Dense(
+ self.layer_size,
+ )(emb)
+ emb = activation(emb)
+
+ action_raw = nn.Dense(self.output_dim)(emb)
+
+ action_logits = jax.nn.log_softmax(action_raw)
+
+ return action_logits
diff --git a/Craftax_Baselines/models/rnd.py b/Craftax_Baselines/models/rnd.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e5f11e8c9f565270ba16595058ec719a3016a1
--- /dev/null
+++ b/Craftax_Baselines/models/rnd.py
@@ -0,0 +1,120 @@
+import jax.numpy as jnp
+import flax.linen as nn
+import numpy as np
+from flax.linen.initializers import constant, orthogonal
+
+import distrax
+
+
+class RNDNetwork(nn.Module):
+ layer_size: int
+ output_dim: int
+ num_layers: int
+
+ @nn.compact
+ def __call__(self, x):
+ activation = nn.relu
+
+ emb = x
+ for _ in range(self.num_layers):
+ emb = nn.Dense(
+ self.layer_size,
+ )(emb)
+ emb = activation(emb)
+
+ emb = nn.Dense(self.output_dim)(emb)
+
+ return emb
+
+
+class ActorCriticRND(nn.Module):
+ action_dim: int
+ layer_width: int
+ activation: str = "tanh"
+
+ @nn.compact
+ def __call__(self, x):
+ if self.activation == "relu":
+ activation = nn.relu
+ else:
+ activation = nn.tanh
+
+ actor_mean = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ actor_mean = activation(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(actor_mean)
+ actor_mean = activation(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(actor_mean)
+ actor_mean = activation(actor_mean)
+
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+ pi = distrax.Categorical(logits=actor_mean)
+
+ # Extrinsic reward
+ critic_e = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ critic_e = activation(critic_e)
+
+ critic_e = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic_e)
+ critic_e = activation(critic_e)
+
+ critic_e = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic_e)
+ critic_e = activation(critic_e)
+
+ critic_e = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic_e
+ )
+
+ # Intrinsic reward
+ critic_i = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(x)
+ critic_i = activation(critic_i)
+
+ critic_i = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic_i)
+ critic_i = activation(critic_i)
+
+ critic_i = nn.Dense(
+ self.layer_width,
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(critic_i)
+ critic_i = activation(critic_i)
+
+ critic_i = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic_i
+ )
+
+ return pi, jnp.squeeze(critic_e, axis=-1), jnp.squeeze(critic_i, axis=-1)
diff --git a/Craftax_Baselines/ppo.py b/Craftax_Baselines/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..403bcf8d9d989f050584aa996f600213853dad7e
--- /dev/null
+++ b/Craftax_Baselines/ppo.py
@@ -0,0 +1,733 @@
+import argparse
+import os
+import sys
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from craftax.craftax_env import make_craftax_env_from_name
+
+import wandb
+from typing import NamedTuple
+
+from flax.training.train_state import TrainState
+import orbax.checkpoint as ocp
+
+from logz.batch_logging import batch_log, create_log_dict
+from models.actor_critic import (
+ ActorCritic,
+ ActorCriticConv,
+)
+from models.icm import ICMEncoder, ICMForward, ICMInverse
+from wrappers import (
+ LogWrapper,
+ OptimisticResetVecEnvWrapper,
+ BatchEnvWrapper,
+ AutoResetEnvWrapper,
+)
+
+# Code adapted from the original implementation made by Chris Lu
+# Original code located at https://github.com/luchris429/purejaxrl
+
+
+class Transition(NamedTuple):
+ done: jnp.ndarray
+ action: jnp.ndarray
+ value: jnp.ndarray
+ reward_e: jnp.ndarray
+ reward_i: jnp.ndarray
+ reward: jnp.ndarray
+ log_prob: jnp.ndarray
+ obs: jnp.ndarray
+ next_obs: jnp.ndarray
+ info: jnp.ndarray
+
+
+def make_train(config):
+ config["NUM_UPDATES"] = (
+ config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
+ )
+ config["MINIBATCH_SIZE"] = (
+ config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
+ )
+
+ env = make_craftax_env_from_name(
+ config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
+ )
+ env_params = env.default_params
+
+ env = LogWrapper(env)
+ if config["USE_OPTIMISTIC_RESETS"]:
+ env = OptimisticResetVecEnvWrapper(
+ env,
+ num_envs=config["NUM_ENVS"],
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
+ )
+ else:
+ env = AutoResetEnvWrapper(env)
+ env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
+
+ def linear_schedule(count):
+ frac = (
+ 1.0
+ - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
+ / config["NUM_UPDATES"]
+ )
+ return config["LR"] * frac
+
+ def train(rng):
+ # INIT NETWORK
+ if "Symbolic" in config["ENV_NAME"]:
+ network = ActorCritic(env.action_space(env_params).n, config["LAYER_SIZE"])
+ else:
+ network = ActorCriticConv(
+ env.action_space(env_params).n, config["LAYER_SIZE"]
+ )
+
+ rng, _rng = jax.random.split(rng)
+ init_x = jnp.zeros((1, *env.observation_space(env_params).shape))
+ network_params = network.init(_rng, init_x)
+ if config["ANNEAL_LR"]:
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
+ )
+ else:
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["LR"], eps=1e-5),
+ )
+ train_state = TrainState.create(
+ apply_fn=network.apply,
+ params=network_params,
+ tx=tx,
+ )
+
+ # Exploration state
+ ex_state = {
+ "icm_encoder": None,
+ "icm_forward": None,
+ "icm_inverse": None,
+ "e3b_matrix": None,
+ }
+
+ if config["TRAIN_ICM"]:
+ obs_shape = env.observation_space(env_params).shape
+ assert len(obs_shape) == 1, "Only configured for 1D observations"
+ obs_shape = obs_shape[0]
+
+ # Encoder
+ icm_encoder_network = ICMEncoder(
+ num_layers=3,
+ output_dim=config["ICM_LATENT_SIZE"],
+ layer_size=config["ICM_LAYER_SIZE"],
+ )
+ rng, _rng = jax.random.split(rng)
+ icm_encoder_network_params = icm_encoder_network.init(
+ _rng, jnp.zeros((1, obs_shape))
+ )
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["ICM_LR"], eps=1e-5),
+ )
+ ex_state["icm_encoder"] = TrainState.create(
+ apply_fn=icm_encoder_network.apply,
+ params=icm_encoder_network_params,
+ tx=tx,
+ )
+
+ # Forward
+ icm_forward_network = ICMForward(
+ num_layers=3,
+ output_dim=config["ICM_LATENT_SIZE"],
+ layer_size=config["ICM_LAYER_SIZE"],
+ num_actions=env.num_actions,
+ )
+ rng, _rng = jax.random.split(rng)
+ icm_forward_network_params = icm_forward_network.init(
+ _rng, jnp.zeros((1, config["ICM_LATENT_SIZE"])), jnp.zeros((1,))
+ )
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["ICM_LR"], eps=1e-5),
+ )
+ ex_state["icm_forward"] = TrainState.create(
+ apply_fn=icm_forward_network.apply,
+ params=icm_forward_network_params,
+ tx=tx,
+ )
+
+ # Inverse
+ icm_inverse_network = ICMInverse(
+ num_layers=3,
+ output_dim=env.num_actions,
+ layer_size=config["ICM_LAYER_SIZE"],
+ )
+ rng, _rng = jax.random.split(rng)
+ icm_inverse_network_params = icm_inverse_network.init(
+ _rng,
+ jnp.zeros((1, config["ICM_LATENT_SIZE"])),
+ jnp.zeros((1, config["ICM_LATENT_SIZE"])),
+ )
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["ICM_LR"], eps=1e-5),
+ )
+ ex_state["icm_inverse"] = TrainState.create(
+ apply_fn=icm_inverse_network.apply,
+ params=icm_inverse_network_params,
+ tx=tx,
+ )
+
+ if config["USE_E3B"]:
+ ex_state["e3b_matrix"] = (
+ jnp.repeat(
+ jnp.expand_dims(
+ jnp.identity(config["ICM_LATENT_SIZE"]), axis=0
+ ),
+ config["NUM_ENVS"],
+ axis=0,
+ )
+ / config["E3B_LAMBDA"]
+ )
+
+ # INIT ENV
+ rng, _rng = jax.random.split(rng)
+ obsv, env_state = env.reset(_rng, env_params)
+
+ # TRAIN LOOP
+ def _update_step(runner_state, unused):
+ # COLLECT TRAJECTORIES
+ def _env_step(runner_state, unused):
+ (
+ train_state,
+ env_state,
+ last_obs,
+ ex_state,
+ rng,
+ update_step,
+ ) = runner_state
+
+ # SELECT ACTION
+ rng, _rng = jax.random.split(rng)
+ pi, value = network.apply(train_state.params, last_obs)
+ action = pi.sample(seed=_rng)
+ log_prob = pi.log_prob(action)
+
+ # STEP ENV
+ rng, _rng = jax.random.split(rng)
+ obsv, env_state, reward_e, done, info = env.step(
+ _rng, env_state, action, env_params
+ )
+
+ reward_i = jnp.zeros(config["NUM_ENVS"])
+
+ if config["TRAIN_ICM"]:
+ latent_obs = ex_state["icm_encoder"].apply_fn(
+ ex_state["icm_encoder"].params, last_obs
+ )
+ latent_next_obs = ex_state["icm_encoder"].apply_fn(
+ ex_state["icm_encoder"].params, obsv
+ )
+
+ latent_next_obs_pred = ex_state["icm_forward"].apply_fn(
+ ex_state["icm_forward"].params, latent_obs, action
+ )
+ error = (latent_next_obs - latent_next_obs_pred) * (
+ 1 - done[:, None]
+ )
+ mse = jnp.square(error).mean(axis=-1)
+
+ reward_i = mse * config["ICM_REWARD_COEFF"]
+
+ if config["USE_E3B"]:
+ # Embedding is (NUM_ENVS, 128)
+ # e3b_matrix is (NUM_ENVS, 128, 128)
+ us = jax.vmap(jnp.matmul)(ex_state["e3b_matrix"], latent_obs)
+ bs = jax.vmap(jnp.dot)(latent_obs, us)
+
+ def update_c(c, b, u):
+ return c - (1.0 / (1 + b)) * jnp.outer(u, u)
+
+ updated_cs = jax.vmap(update_c)(ex_state["e3b_matrix"], bs, us)
+ new_cs = (
+ jnp.repeat(
+ jnp.expand_dims(
+ jnp.identity(config["ICM_LATENT_SIZE"]), axis=0
+ ),
+ config["NUM_ENVS"],
+ axis=0,
+ )
+ / config["E3B_LAMBDA"]
+ )
+ ex_state["e3b_matrix"] = jnp.where(
+ done[:, None, None], new_cs, updated_cs
+ )
+
+ e3b_bonus = jnp.where(
+ done, jnp.zeros((config["NUM_ENVS"],)), bs
+ )
+
+ reward_i = e3b_bonus * config["E3B_REWARD_COEFF"]
+
+ reward = reward_e + reward_i
+
+ transition = Transition(
+ done=done,
+ action=action,
+ value=value,
+ reward=reward,
+ reward_i=reward_i,
+ reward_e=reward_e,
+ log_prob=log_prob,
+ obs=last_obs,
+ next_obs=obsv,
+ info=info,
+ )
+ runner_state = (
+ train_state,
+ env_state,
+ obsv,
+ ex_state,
+ rng,
+ update_step,
+ )
+ return runner_state, transition
+
+ runner_state, traj_batch = jax.lax.scan(
+ _env_step, runner_state, None, config["NUM_STEPS"]
+ )
+
+ # CALCULATE ADVANTAGE
+ (
+ train_state,
+ env_state,
+ last_obs,
+ ex_state,
+ rng,
+ update_step,
+ ) = runner_state
+ _, last_val = network.apply(train_state.params, last_obs)
+
+ def _calculate_gae(traj_batch, last_val):
+ def _get_advantages(gae_and_next_value, transition):
+ gae, next_value = gae_and_next_value
+ done, value, reward = (
+ transition.done,
+ transition.value,
+ transition.reward,
+ )
+ delta = reward + config["GAMMA"] * next_value * (1 - done) - value
+ gae = (
+ delta
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
+ )
+ return (gae, value), gae
+
+ _, advantages = jax.lax.scan(
+ _get_advantages,
+ (jnp.zeros_like(last_val), last_val),
+ traj_batch,
+ reverse=True,
+ unroll=16,
+ )
+ return advantages, advantages + traj_batch.value
+
+ advantages, targets = _calculate_gae(traj_batch, last_val)
+
+ # UPDATE NETWORK
+ def _update_epoch(update_state, unused):
+ def _update_minbatch(train_state, batch_info):
+ traj_batch, advantages, targets = batch_info
+
+ # Policy/value network
+ def _loss_fn(params, traj_batch, gae, targets):
+ # RERUN NETWORK
+ pi, value = network.apply(params, traj_batch.obs)
+ log_prob = pi.log_prob(traj_batch.action)
+
+ # CALCULATE VALUE LOSS
+ value_pred_clipped = traj_batch.value + (
+ value - traj_batch.value
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
+ value_losses = jnp.square(value - targets)
+ value_losses_clipped = jnp.square(value_pred_clipped - targets)
+ value_loss = (
+ 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
+ )
+
+ # CALCULATE ACTOR LOSS
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
+ loss_actor1 = ratio * gae
+ loss_actor2 = (
+ jnp.clip(
+ ratio,
+ 1.0 - config["CLIP_EPS"],
+ 1.0 + config["CLIP_EPS"],
+ )
+ * gae
+ )
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
+ loss_actor = loss_actor.mean()
+ entropy = pi.entropy().mean()
+
+ total_loss = (
+ loss_actor
+ + config["VF_COEF"] * value_loss
+ - config["ENT_COEF"] * entropy
+ )
+ return total_loss, (value_loss, loss_actor, entropy)
+
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+ total_loss, grads = grad_fn(
+ train_state.params, traj_batch, advantages, targets
+ )
+ train_state = train_state.apply_gradients(grads=grads)
+
+ losses = (total_loss, 0)
+ return train_state, losses
+
+ (
+ train_state,
+ traj_batch,
+ advantages,
+ targets,
+ rng,
+ ) = update_state
+ rng, _rng = jax.random.split(rng)
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
+ assert (
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
+ ), "batch size must be equal to number of steps * number of envs"
+ permutation = jax.random.permutation(_rng, batch_size)
+ batch = (traj_batch, advantages, targets)
+ batch = jax.tree.map(
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
+ )
+ shuffled_batch = jax.tree.map(
+ lambda x: jnp.take(x, permutation, axis=0), batch
+ )
+ minibatches = jax.tree.map(
+ lambda x: jnp.reshape(
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
+ ),
+ shuffled_batch,
+ )
+ train_state, losses = jax.lax.scan(
+ _update_minbatch, train_state, minibatches
+ )
+ update_state = (
+ train_state,
+ traj_batch,
+ advantages,
+ targets,
+ rng,
+ )
+ return update_state, losses
+
+ update_state = (
+ train_state,
+ traj_batch,
+ advantages,
+ targets,
+ rng,
+ )
+ update_state, loss_info = jax.lax.scan(
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
+ )
+
+ train_state = update_state[0]
+ metric = jax.tree.map(
+ lambda x: (x * traj_batch.info["returned_episode"]).sum()
+ / traj_batch.info["returned_episode"].sum(),
+ traj_batch.info,
+ )
+
+ rng = update_state[-1]
+
+ # UPDATE EXPLORATION STATE
+ def _update_ex_epoch(update_state, unused):
+ def _update_ex_minbatch(ex_state, traj_batch):
+ def _inverse_loss_fn(
+ icm_encoder_params, icm_inverse_params, traj_batch
+ ):
+ latent_obs = ex_state["icm_encoder"].apply_fn(
+ icm_encoder_params, traj_batch.obs
+ )
+ latent_next_obs = ex_state["icm_encoder"].apply_fn(
+ icm_encoder_params, traj_batch.next_obs
+ )
+
+ action_pred_logits = ex_state["icm_inverse"].apply_fn(
+ icm_inverse_params, latent_obs, latent_next_obs
+ )
+ true_action = jax.nn.one_hot(
+ traj_batch.action, num_classes=action_pred_logits.shape[-1]
+ )
+
+ bce = -jnp.mean(
+ jnp.sum(
+ action_pred_logits
+ * true_action
+ * (1 - traj_batch.done[:, None]),
+ axis=1,
+ )
+ )
+
+ return bce * config["ICM_INVERSE_LOSS_COEF"]
+
+ inverse_grad_fn = jax.value_and_grad(
+ _inverse_loss_fn,
+ has_aux=False,
+ argnums=(
+ 0,
+ 1,
+ ),
+ )
+ inverse_loss, grads = inverse_grad_fn(
+ ex_state["icm_encoder"].params,
+ ex_state["icm_inverse"].params,
+ traj_batch,
+ )
+ icm_encoder_grad, icm_inverse_grad = grads
+ ex_state["icm_encoder"] = ex_state["icm_encoder"].apply_gradients(
+ grads=icm_encoder_grad
+ )
+ ex_state["icm_inverse"] = ex_state["icm_inverse"].apply_gradients(
+ grads=icm_inverse_grad
+ )
+
+ def _forward_loss_fn(icm_forward_params, traj_batch):
+ latent_obs = ex_state["icm_encoder"].apply_fn(
+ ex_state["icm_encoder"].params, traj_batch.obs
+ )
+ latent_next_obs = ex_state["icm_encoder"].apply_fn(
+ ex_state["icm_encoder"].params, traj_batch.next_obs
+ )
+
+ latent_next_obs_pred = ex_state["icm_forward"].apply_fn(
+ icm_forward_params, latent_obs, traj_batch.action
+ )
+
+ error = (latent_next_obs - latent_next_obs_pred) * (
+ 1 - traj_batch.done[:, None]
+ )
+ return (
+ jnp.square(error).mean() * config["ICM_FORWARD_LOSS_COEF"]
+ )
+
+ forward_grad_fn = jax.value_and_grad(
+ _forward_loss_fn, has_aux=False
+ )
+ forward_loss, icm_forward_grad = forward_grad_fn(
+ ex_state["icm_forward"].params, traj_batch
+ )
+ ex_state["icm_forward"] = ex_state["icm_forward"].apply_gradients(
+ grads=icm_forward_grad
+ )
+
+ losses = (inverse_loss, forward_loss)
+ return ex_state, losses
+
+ (ex_state, traj_batch, rng) = update_state
+ rng, _rng = jax.random.split(rng)
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
+ assert (
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
+ ), "batch size must be equal to number of steps * number of envs"
+ permutation = jax.random.permutation(_rng, batch_size)
+ batch = jax.tree.map(
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
+ )
+ shuffled_batch = jax.tree.map(
+ lambda x: jnp.take(x, permutation, axis=0), batch
+ )
+ minibatches = jax.tree.map(
+ lambda x: jnp.reshape(
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
+ ),
+ shuffled_batch,
+ )
+ ex_state, losses = jax.lax.scan(
+ _update_ex_minbatch, ex_state, minibatches
+ )
+ update_state = (ex_state, traj_batch, rng)
+ return update_state, losses
+
+ if config["TRAIN_ICM"]:
+ ex_update_state = (ex_state, traj_batch, rng)
+ ex_update_state, ex_loss = jax.lax.scan(
+ _update_ex_epoch,
+ ex_update_state,
+ None,
+ config["EXPLORATION_UPDATE_EPOCHS"],
+ )
+ metric["icm_inverse_loss"] = ex_loss[0].mean()
+ metric["icm_forward_loss"] = ex_loss[1].mean()
+ metric["reward_i"] = traj_batch.reward_i.mean()
+ metric["reward_e"] = traj_batch.reward_e.mean()
+
+ ex_state = ex_update_state[0]
+ rng = ex_update_state[-1]
+
+ # wandb logging
+ if config["DEBUG"] and config["USE_WANDB"]:
+
+ def callback(metric, update_step):
+ to_log = create_log_dict(metric, config)
+ batch_log(update_step, to_log, config)
+
+ jax.debug.callback(
+ callback,
+ metric,
+ update_step,
+ )
+
+ runner_state = (
+ train_state,
+ env_state,
+ last_obs,
+ ex_state,
+ rng,
+ update_step + 1,
+ )
+ return runner_state, metric
+
+ rng, _rng = jax.random.split(rng)
+ runner_state = (
+ train_state,
+ env_state,
+ obsv,
+ ex_state,
+ _rng,
+ 0,
+ )
+ runner_state, metric = jax.lax.scan(
+ _update_step, runner_state, None, config["NUM_UPDATES"]
+ )
+ return {"runner_state": runner_state} # , "info": metric}
+
+ return train
+
+
+def run_ppo(config):
+ config = {k.upper(): v for k, v in config.__dict__.items()}
+
+ if config["USE_WANDB"]:
+ wandb.init(
+ project=config["WANDB_PROJECT"],
+ entity=config["WANDB_ENTITY"],
+ config=config,
+ name=config["ENV_NAME"]
+ + "-"
+ + str(int(config["TOTAL_TIMESTEPS"] // 1e6))
+ + "M",
+ )
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
+
+ train_jit = jax.jit(make_train(config))
+ train_vmap = jax.vmap(train_jit)
+
+ t0 = time.time()
+ out = train_vmap(rngs)
+ t1 = time.time()
+ print("Time to run experiment", t1 - t0)
+ print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
+
+ if config["USE_WANDB"]:
+
+ def _save_network(rs_index, dir_name):
+ train_states = out["runner_state"][rs_index]
+ train_state = jax.tree.map(lambda x: x[0], train_states)
+
+ path = os.path.join(wandb.run.dir, dir_name)
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
+
+ with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
+ checkpoint_manager.save(
+ int(config["TOTAL_TIMESTEPS"]),
+ args=ocp.args.StandardSave(train_state)
+ )
+
+ print(f"saved runner state to {path}")
+
+ if config["SAVE_POLICY"]:
+ _save_network(0, "policies")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
+ parser.add_argument(
+ "--num_envs",
+ type=int,
+ default=1024,
+ )
+ parser.add_argument(
+ "--total_timesteps", type=lambda x: int(float(x)), default=1e9
+ ) # Allow scientific notation
+ parser.add_argument("--lr", type=float, default=2e-4)
+ parser.add_argument("--num_steps", type=int, default=64)
+ parser.add_argument("--update_epochs", type=int, default=4)
+ parser.add_argument("--num_minibatches", type=int, default=8)
+ parser.add_argument("--gamma", type=float, default=0.99)
+ parser.add_argument("--gae_lambda", type=float, default=0.8)
+ parser.add_argument("--clip_eps", type=float, default=0.2)
+ parser.add_argument("--ent_coef", type=float, default=0.01)
+ parser.add_argument("--vf_coef", type=float, default=0.5)
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
+ parser.add_argument("--activation", type=str, default="tanh")
+ parser.add_argument(
+ "--anneal_lr", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--seed", type=int)
+ parser.add_argument(
+ "--use_wandb", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--save_policy", action="store_true")
+ parser.add_argument("--num_repeats", type=int, default=1)
+ parser.add_argument("--layer_size", type=int, default=512)
+ parser.add_argument("--wandb_project", type=str)
+ parser.add_argument("--wandb_entity", type=str)
+ parser.add_argument(
+ "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
+
+ # EXPLORATION
+ parser.add_argument("--exploration_update_epochs", type=int, default=4)
+ # ICM
+ parser.add_argument("--icm_reward_coeff", type=float, default=1.0)
+ parser.add_argument("--train_icm", action="store_true")
+ parser.add_argument("--icm_lr", type=float, default=3e-4)
+ parser.add_argument("--icm_forward_loss_coef", type=float, default=1.0)
+ parser.add_argument("--icm_inverse_loss_coef", type=float, default=1.0)
+ parser.add_argument("--icm_layer_size", type=int, default=256)
+ parser.add_argument("--icm_latent_size", type=int, default=32)
+ # E3B
+ parser.add_argument("--e3b_reward_coeff", type=float, default=1.0)
+ parser.add_argument("--use_e3b", action="store_true")
+ parser.add_argument("--e3b_lambda", type=float, default=0.1)
+
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
+ if rest_args:
+ raise ValueError(f"Unknown args {rest_args}")
+
+ if args.use_e3b:
+ assert args.train_icm
+ assert args.icm_reward_coeff == 0
+ if args.seed is None:
+ args.seed = np.random.randint(2**31)
+
+ if args.jit:
+ run_ppo(args)
+ else:
+ with jax.disable_jit():
+ run_ppo(args)
diff --git a/Craftax_Baselines/ppo_rnd.py b/Craftax_Baselines/ppo_rnd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7256b5ec716def468d05599a4e7ae68755765944
--- /dev/null
+++ b/Craftax_Baselines/ppo_rnd.py
@@ -0,0 +1,680 @@
+import argparse
+import os
+import sys
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from craftax.craftax_env import make_craftax_env_from_name
+
+import wandb
+from typing import NamedTuple
+
+from flax.training.train_state import TrainState
+import orbax.checkpoint as ocp
+
+from logz.batch_logging import batch_log, create_log_dict
+from wrappers import (
+ LogWrapper,
+ OptimisticResetVecEnvWrapper,
+ AutoResetEnvWrapper,
+ BatchEnvWrapper,
+)
+from models.rnd import RNDNetwork, ActorCriticRND
+
+# Code adapted from the original implementation made by Chris Lu
+# Original code located at https://github.com/luchris429/purejaxrl
+
+
+class Transition(NamedTuple):
+ done: jnp.ndarray
+ action: jnp.ndarray
+ value_e: jnp.ndarray
+ value_i: jnp.ndarray
+ reward_e: jnp.ndarray
+ reward_i: jnp.ndarray
+ reward: jnp.ndarray
+ log_prob: jnp.ndarray
+ obs: jnp.ndarray
+ next_obs: jnp.ndarray
+ info: jnp.ndarray
+
+
+def make_train(config):
+ config["NUM_UPDATES"] = (
+ config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
+ )
+ config["MINIBATCH_SIZE"] = (
+ config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
+ )
+
+ env = make_craftax_env_from_name(
+ config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
+ )
+ env_params = env.default_params
+
+ env = LogWrapper(env)
+ if config["USE_OPTIMISTIC_RESETS"]:
+ env = OptimisticResetVecEnvWrapper(
+ env,
+ num_envs=config["NUM_ENVS"],
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
+ )
+ else:
+ env = AutoResetEnvWrapper(env)
+ env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
+
+ def linear_schedule(count):
+ frac = (
+ 1.0
+ - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
+ / config["NUM_UPDATES"]
+ )
+ return config["LR"] * frac
+
+ def train(rng):
+ # INIT NETWORK
+ if "Symbolic" in config["ENV_NAME"]:
+ network = ActorCriticRND(
+ env.action_space(env_params).n, config["LAYER_SIZE"]
+ )
+ else:
+ raise ValueError
+ # network = ActorCriticConv(
+ # env.action_space(env_params).n, config["LAYER_SIZE"]
+ # )
+
+ rng, _rng = jax.random.split(rng)
+ init_x = jnp.zeros((1, *env.observation_space(env_params).shape))
+ network_params = network.init(_rng, init_x)
+ if config["ANNEAL_LR"]:
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
+ )
+ else:
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["LR"], eps=1e-5),
+ )
+ train_state = TrainState.create(
+ apply_fn=network.apply,
+ params=network_params,
+ tx=tx,
+ )
+
+ # Exploration state
+ ex_state = {
+ "rnd_model": None,
+ }
+
+ if config["USE_RND"]:
+ obs_shape = env.observation_space(env_params).shape
+ assert len(obs_shape) == 1, "Only configured for 1D observations"
+ obs_shape = obs_shape[0]
+
+ # Random network
+ rnd_random_network = RNDNetwork(
+ num_layers=3,
+ output_dim=config["RND_OUTPUT_SIZE"],
+ layer_size=config["RND_LAYER_SIZE"],
+ )
+ rng, _rng = jax.random.split(rng)
+ rnd_random_network_params = rnd_random_network.init(
+ _rng, jnp.zeros((1, obs_shape))
+ )
+
+ # Distillation Network
+ rnd_distillation_network = RNDNetwork(
+ num_layers=3,
+ output_dim=config["RND_OUTPUT_SIZE"],
+ layer_size=config["RND_LAYER_SIZE"],
+ )
+ rng, _rng = jax.random.split(rng)
+ rnd_distillation_network_params = rnd_distillation_network.init(
+ _rng, jnp.zeros((1, obs_shape))
+ )
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["RND_LR"], eps=1e-5),
+ )
+ ex_state["rnd_distillation_network"] = TrainState.create(
+ apply_fn=rnd_distillation_network.apply,
+ params=rnd_distillation_network_params,
+ tx=tx,
+ )
+
+ # INIT ENV
+ rng, _rng = jax.random.split(rng)
+ obsv, env_state = env.reset(_rng, env_params)
+
+ # TRAIN LOOP
+ def _update_step(runner_state, unused):
+ # COLLECT TRAJECTORIES
+ def _env_step(runner_state, unused):
+ (
+ train_state,
+ env_state,
+ last_obs,
+ ex_state,
+ rng,
+ update_step,
+ ) = runner_state
+
+ # SELECT ACTION
+ rng, _rng = jax.random.split(rng)
+ pi, value_e, value_i = network.apply(train_state.params, last_obs)
+ action = pi.sample(seed=_rng)
+ log_prob = pi.log_prob(action)
+
+ # STEP ENV
+ rng, _rng = jax.random.split(rng)
+ obsv, env_state, reward_e, done, info = env.step(
+ _rng, env_state, action, env_params
+ )
+
+ reward_i = jnp.zeros(config["NUM_ENVS"])
+
+ if config["USE_RND"]:
+ random_pred = rnd_random_network.apply(
+ rnd_random_network_params, obsv
+ )
+
+ distill_pred = ex_state["rnd_distillation_network"].apply_fn(
+ ex_state["rnd_distillation_network"].params, obsv
+ )
+ error = (random_pred - distill_pred) * (1 - done[:, None])
+ mse = jnp.square(error).mean(axis=-1)
+
+ reward_i = mse * config["RND_REWARD_COEFF"]
+
+ reward = reward_e + reward_i
+
+ transition = Transition(
+ done=done,
+ action=action,
+ value_e=value_e,
+ value_i=value_i,
+ reward=reward,
+ reward_i=reward_i,
+ reward_e=reward_e,
+ log_prob=log_prob,
+ obs=last_obs,
+ next_obs=obsv,
+ info=info,
+ )
+ runner_state = (
+ train_state,
+ env_state,
+ obsv,
+ ex_state,
+ rng,
+ update_step,
+ )
+ return runner_state, transition
+
+ runner_state, traj_batch = jax.lax.scan(
+ _env_step, runner_state, None, config["NUM_STEPS"]
+ )
+
+ # CALCULATE ADVANTAGE
+ (
+ train_state,
+ env_state,
+ last_obs,
+ ex_state,
+ rng,
+ update_step,
+ ) = runner_state
+ _, last_val_e, last_val_i = network.apply(train_state.params, last_obs)
+
+ def _calculate_gae(traj_batch, last_val, is_extrinsic):
+ def _get_advantages(gae_and_next_value, transition):
+ gae, next_value, is_extrinsic = gae_and_next_value
+ done, value, reward = (
+ transition.done,
+ jax.lax.select(
+ is_extrinsic, transition.value_e, transition.value_i
+ ),
+ jax.lax.select(
+ is_extrinsic, transition.reward_e, transition.reward_i
+ ),
+ )
+ done = jnp.logical_and(
+ done, jnp.logical_or(config["RND_IS_EPISODIC"], is_extrinsic)
+ )
+
+ delta = reward + config["GAMMA"] * next_value * (1 - done) - value
+ gae = (
+ delta
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
+ )
+ return (gae, value, is_extrinsic), gae
+
+ _, advantages = jax.lax.scan(
+ _get_advantages,
+ (jnp.zeros_like(last_val), last_val, is_extrinsic),
+ traj_batch,
+ reverse=True,
+ unroll=16,
+ )
+ return advantages, advantages + jax.lax.select(
+ is_extrinsic, traj_batch.value_e, traj_batch.value_i
+ )
+
+ advantages_e, targets_e = _calculate_gae(traj_batch, last_val_e, True)
+ advantages_i, targets_i = _calculate_gae(traj_batch, last_val_i, False)
+
+ # UPDATE NETWORK
+ def _update_epoch(update_state, unused):
+ def _update_minbatch(train_state, batch_info):
+ (
+ traj_batch,
+ advantages_e,
+ targets_e,
+ advantages_i,
+ targets_i,
+ ) = batch_info
+
+ # Policy/value network
+ def _loss_fn(
+ params, traj_batch, gae_e, targets_e, gae_i, targets_i
+ ):
+ # RERUN NETWORK
+ pi, value_e, value_i = network.apply(params, traj_batch.obs)
+ log_prob = pi.log_prob(traj_batch.action)
+
+ # CALCULATE EXTRINSIC VALUE LOSS
+ value_pred_clipped_e = traj_batch.value_e + (
+ value_e - traj_batch.value_e
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
+ value_losses_e = jnp.square(value_e - targets_e)
+ value_losses_clipped_e = jnp.square(
+ value_pred_clipped_e - targets_e
+ )
+ value_loss_e = (
+ 0.5
+ * jnp.maximum(value_losses_e, value_losses_clipped_e).mean()
+ )
+
+ # CALCULATE INTRINSIC VALUE LOSS
+ value_pred_clipped_i = traj_batch.value_i + (
+ value_i - traj_batch.value_i
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
+ value_losses_i = jnp.square(value_i - targets_i)
+ value_losses_clipped_i = jnp.square(
+ value_pred_clipped_i - targets_i
+ )
+ value_loss_i = (
+ 0.5
+ * jnp.maximum(value_losses_i, value_losses_clipped_i).mean()
+ )
+
+ # CALCULATE ACTOR LOSS
+ gae = gae_e
+ if config["USE_RND"]:
+ gae += gae_i * config["RND_GAE_COEFF"]
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
+ loss_actor1 = ratio * gae
+ loss_actor2 = (
+ jnp.clip(
+ ratio,
+ 1.0 - config["CLIP_EPS"],
+ 1.0 + config["CLIP_EPS"],
+ )
+ * gae
+ )
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
+ loss_actor = loss_actor.mean()
+ entropy = pi.entropy().mean()
+
+ value_loss = value_loss_e
+ if config["USE_RND"]:
+ value_loss += value_loss_i
+
+ total_loss = (
+ loss_actor
+ + config["VF_COEF"] * value_loss
+ - config["ENT_COEF"] * entropy
+ )
+ return total_loss, (
+ value_loss_e,
+ value_loss_i,
+ loss_actor,
+ entropy,
+ )
+
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+ total_loss, grads = grad_fn(
+ train_state.params,
+ traj_batch,
+ advantages_e,
+ targets_e,
+ advantages_i,
+ targets_i,
+ )
+ train_state = train_state.apply_gradients(grads=grads)
+
+ losses = (total_loss, 0)
+ return train_state, losses
+
+ (
+ train_state,
+ traj_batch,
+ advantages_e,
+ targets_e,
+ advantages_i,
+ targets_i,
+ rng,
+ ) = update_state
+ rng, _rng = jax.random.split(rng)
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
+ assert (
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
+ ), "batch size must be equal to number of steps * number of envs"
+ permutation = jax.random.permutation(_rng, batch_size)
+ batch = (
+ traj_batch,
+ advantages_e,
+ targets_e,
+ advantages_i,
+ targets_i,
+ )
+ batch = jax.tree.map(
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
+ )
+ shuffled_batch = jax.tree.map(
+ lambda x: jnp.take(x, permutation, axis=0), batch
+ )
+ minibatches = jax.tree.map(
+ lambda x: jnp.reshape(
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
+ ),
+ shuffled_batch,
+ )
+ train_state, losses = jax.lax.scan(
+ _update_minbatch, train_state, minibatches
+ )
+ update_state = (
+ train_state,
+ traj_batch,
+ advantages_e,
+ targets_e,
+ advantages_i,
+ targets_i,
+ rng,
+ )
+ return update_state, losses
+
+ update_state = (
+ train_state,
+ traj_batch,
+ advantages_e,
+ targets_e,
+ advantages_i,
+ targets_i,
+ rng,
+ )
+ update_state, loss_info = jax.lax.scan(
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
+ )
+
+ train_state = update_state[0]
+ metric = jax.tree.map(
+ lambda x: (x * traj_batch.info["returned_episode"]).sum()
+ / traj_batch.info["returned_episode"].sum(),
+ traj_batch.info,
+ )
+
+ rng = update_state[-1]
+
+ # UPDATE EXPLORATION STATE
+ def _update_ex_epoch(update_state, unused):
+ def _update_ex_minbatch(ex_state, traj_batch):
+ rnd_loss = 0
+
+ if config["USE_RND"]:
+
+ def _rnd_loss_fn(rnd_distillation_params, traj_batch):
+ random_network_out = rnd_random_network.apply(
+ rnd_random_network_params, traj_batch.next_obs
+ )
+
+ distillation_network_out = ex_state[
+ "rnd_distillation_network"
+ ].apply_fn(rnd_distillation_params, traj_batch.next_obs)
+
+ error = (random_network_out - distillation_network_out) * (
+ 1 - traj_batch.done[:, None]
+ )
+ return jnp.square(error).mean() * config["RND_LOSS_COEFF"]
+
+ rnd_grad_fn = jax.value_and_grad(_rnd_loss_fn, has_aux=False)
+ rnd_loss, rnd_grad = rnd_grad_fn(
+ ex_state["rnd_distillation_network"].params, traj_batch
+ )
+ ex_state["rnd_distillation_network"] = ex_state[
+ "rnd_distillation_network"
+ ].apply_gradients(grads=rnd_grad)
+
+ losses = (rnd_loss,)
+ return ex_state, losses
+
+ (ex_state, traj_batch, rng) = update_state
+ rng, _rng = jax.random.split(rng)
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
+ assert (
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
+ ), "batch size must be equal to number of steps * number of envs"
+ permutation = jax.random.permutation(_rng, batch_size)
+ batch = jax.tree.map(
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
+ )
+ shuffled_batch = jax.tree.map(
+ lambda x: jnp.take(x, permutation, axis=0), batch
+ )
+ minibatches = jax.tree.map(
+ lambda x: jnp.reshape(
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
+ ),
+ shuffled_batch,
+ )
+ ex_state, losses = jax.lax.scan(
+ _update_ex_minbatch, ex_state, minibatches
+ )
+ update_state = (ex_state, traj_batch, rng)
+ return update_state, losses
+
+ if config["USE_RND"]:
+ ex_update_state = (ex_state, traj_batch, rng)
+ ex_update_state, ex_loss = jax.lax.scan(
+ _update_ex_epoch,
+ ex_update_state,
+ None,
+ config["EXPLORATION_UPDATE_EPOCHS"],
+ )
+ metric["rnd_loss"] = ex_loss[0].mean()
+ metric["reward_i"] = traj_batch.reward_i.mean()
+ metric["reward_e"] = traj_batch.reward_e.mean()
+
+ ex_state = ex_update_state[0]
+ rng = ex_update_state[-1]
+
+ # wandb logging
+ if config["DEBUG"] and config["USE_WANDB"]:
+
+ def callback(
+ metric, update_step
+ ): # , loss_info, traj_batch, ex_state, advantages_i, targets_i):
+ to_log = create_log_dict(metric, config)
+ batch_log(update_step, to_log, config)
+
+ jax.debug.callback(
+ callback,
+ metric,
+ update_step,
+ # loss_info, traj_batch, ex_state, advantages_i, targets_i
+ )
+
+ runner_state = (
+ train_state,
+ env_state,
+ last_obs,
+ ex_state,
+ rng,
+ update_step + 1,
+ )
+ return runner_state, metric
+
+ rng, _rng = jax.random.split(rng)
+ runner_state = (
+ train_state,
+ env_state,
+ obsv,
+ ex_state,
+ _rng,
+ 0,
+ )
+ runner_state, metric = jax.lax.scan(
+ _update_step, runner_state, None, config["NUM_UPDATES"]
+ )
+ return {"runner_state": runner_state} # , "info": metric}
+
+ return train
+
+
+def run_ppo(config):
+ config = {k.upper(): v for k, v in config.__dict__.items()}
+
+ if config["USE_WANDB"]:
+ wandb.init(
+ project=config["WANDB_PROJECT"],
+ entity=config["WANDB_ENTITY"],
+ config=config,
+ name=config["ENV_NAME"]
+ + "-PPO_RND-"
+ + str(int(config["TOTAL_TIMESTEPS"] // 1e6))
+ + "M",
+ )
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
+
+ train_jit = jax.jit(make_train(config))
+ train_vmap = jax.vmap(train_jit)
+
+ t0 = time.time()
+ out = train_vmap(rngs)
+ t1 = time.time()
+ print("Time to run experiment", t1 - t0)
+ print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
+ # t1 = time.time()
+ # out = train_vmap(rngs)
+ # t2 = time.time()
+ # print("t2", t2 - t1)
+ # print("SPS2: ", config["TOTAL_TIMESTEPS"] / (t2 - t1))
+
+ if config["USE_WANDB"]:
+ # if config["DEBUG"] == "end":
+ # info = out["info"]
+ # for update in range(info["timestep"].shape[1]):
+ # if update % 10 == 0:
+ # for repeat in range(info["timestep"].shape[0]):
+ # update_info = jax.tree.map(lambda x: x[repeat, update], info)
+ # to_log = create_log_dict(update_info)
+ # batch_log(update, to_log, config)
+ #
+ # t2 = time.time()
+ # print("Time to log to wandb", t2 - t1)
+
+ def _save_network(rs_index, dir_name):
+ train_states = out["runner_state"][rs_index]
+ train_state = jax.tree.map(lambda x: x[0], train_states)
+
+ path = os.path.join(wandb.run.dir, dir_name)
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
+
+ with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
+ checkpoint_manager.save(
+ int(config["TOTAL_TIMESTEPS"]),
+ args=ocp.args.StandardSave(train_state)
+ )
+
+ print(f"saved runner state to {path}")
+
+ if config["SAVE_POLICY"]:
+ _save_network(0, "policies")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
+ parser.add_argument(
+ "--num_envs",
+ type=int,
+ default=1024,
+ )
+ parser.add_argument(
+ "--total_timesteps", type=lambda x: int(float(x)), default=1e9
+ ) # Allow scientific notation
+ parser.add_argument("--lr", type=float, default=2e-4)
+ parser.add_argument("--num_steps", type=int, default=64)
+ parser.add_argument("--update_epochs", type=int, default=4)
+ parser.add_argument("--num_minibatches", type=int, default=8)
+ parser.add_argument("--gamma", type=float, default=0.99)
+ parser.add_argument("--gae_lambda", type=float, default=0.8)
+ parser.add_argument("--clip_eps", type=float, default=0.2)
+ parser.add_argument("--ent_coef", type=float, default=0.01)
+ parser.add_argument("--vf_coef", type=float, default=0.5)
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
+ parser.add_argument("--activation", type=str, default="tanh")
+ parser.add_argument(
+ "--anneal_lr", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--seed", type=int)
+ parser.add_argument(
+ "--use_wandb", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--save_policy", action="store_true")
+ parser.add_argument("--num_repeats", type=int, default=1)
+ parser.add_argument("--layer_size", type=int, default=512)
+ parser.add_argument("--wandb_project", type=str)
+ parser.add_argument("--wandb_entity", type=str)
+ parser.add_argument(
+ "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
+
+ # EXPLORATION
+ parser.add_argument("--exploration_update_epochs", type=int, default=1)
+ # RND
+ parser.add_argument(
+ "--use_rnd", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--rnd_layer_size", type=int, default=256)
+ parser.add_argument("--rnd_output_size", type=int, default=512)
+ parser.add_argument("--rnd_lr", type=float, default=3e-4)
+ parser.add_argument("--rnd_reward_coeff", type=float, default=1.0)
+ parser.add_argument("--rnd_loss_coeff", type=float, default=0.01)
+ parser.add_argument("--rnd_gae_coeff", type=float, default=0.01)
+ parser.add_argument(
+ "--rnd_is_episodic", action=argparse.BooleanOptionalAction, default=False
+ )
+
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
+ if rest_args:
+ raise ValueError(f"Unknown args {rest_args}")
+
+ if args.seed is None:
+ args.seed = np.random.randint(2**31)
+
+ if args.jit:
+ run_ppo(args)
+ else:
+ with jax.disable_jit():
+ run_ppo(args)
diff --git a/Craftax_Baselines/ppo_rnn.py b/Craftax_Baselines/ppo_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7868e559e14563d05e1b5c7b9429bd38ad1fbedc
--- /dev/null
+++ b/Craftax_Baselines/ppo_rnn.py
@@ -0,0 +1,542 @@
+import argparse
+import os
+import sys
+
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+import numpy as np
+import optax
+import time
+
+import orbax.checkpoint as ocp
+
+import wandb
+from flax.linen.initializers import constant, orthogonal
+from typing import NamedTuple, Dict
+from flax.training.train_state import TrainState
+import distrax
+import functools
+
+from wrappers import (
+ LogWrapper,
+ OptimisticResetVecEnvWrapper,
+ BatchEnvWrapper,
+ AutoResetEnvWrapper,
+)
+from logz.batch_logging import create_log_dict, batch_log
+
+from craftax.craftax_env import make_craftax_env_from_name
+
+# Code adapted from the original implementation made by Chris Lu
+# Original code located at https://github.com/luchris429/purejaxrl
+
+
+class ScannedRNN(nn.Module):
+ @functools.partial(
+ nn.scan,
+ variable_broadcast="params",
+ in_axes=0,
+ out_axes=0,
+ split_rngs={"params": False},
+ )
+ @nn.compact
+ def __call__(self, carry, x):
+ """Applies the module."""
+ rnn_state = carry
+ ins, resets = x
+ rnn_state = jnp.where(
+ resets[:, np.newaxis],
+ self.initialize_carry(ins.shape[0], ins.shape[1]),
+ rnn_state,
+ )
+ new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
+ return new_rnn_state, y
+
+ @staticmethod
+ def initialize_carry(batch_size, hidden_size):
+ # Use a dummy key since the default state init fn is just zeros.
+ cell = nn.GRUCell(features=hidden_size)
+ return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
+
+
+class ActorCriticRNN(nn.Module):
+ action_dim: int
+ config: Dict
+
+ @nn.compact
+ def __call__(self, hidden, x):
+ obs, dones = x
+ embedding = nn.Dense(
+ self.config["LAYER_SIZE"],
+ kernel_init=orthogonal(np.sqrt(2)),
+ bias_init=constant(0.0),
+ )(obs)
+ embedding = nn.relu(embedding)
+
+ rnn_in = (embedding, dones)
+ hidden, embedding = ScannedRNN()(hidden, rnn_in)
+
+ actor_mean = nn.Dense(
+ self.config["LAYER_SIZE"],
+ kernel_init=orthogonal(2),
+ bias_init=constant(0.0),
+ )(embedding)
+ actor_mean = nn.relu(actor_mean)
+ actor_mean = nn.Dense(
+ self.config["LAYER_SIZE"],
+ kernel_init=orthogonal(2),
+ bias_init=constant(0.0),
+ )(actor_mean)
+ actor_mean = nn.relu(actor_mean)
+ actor_mean = nn.Dense(
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+ )(actor_mean)
+
+ pi = distrax.Categorical(logits=actor_mean)
+
+ critic = nn.Dense(
+ self.config["LAYER_SIZE"],
+ kernel_init=orthogonal(2),
+ bias_init=constant(0.0),
+ )(embedding)
+ critic = nn.relu(critic)
+ critic = nn.Dense(
+ self.config["LAYER_SIZE"],
+ kernel_init=orthogonal(2),
+ bias_init=constant(0.0),
+ )(critic)
+ critic = nn.relu(critic)
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+ critic
+ )
+
+ return hidden, pi, jnp.squeeze(critic, axis=-1)
+
+
+class Transition(NamedTuple):
+ done: jnp.ndarray
+ action: jnp.ndarray
+ value: jnp.ndarray
+ reward: jnp.ndarray
+ log_prob: jnp.ndarray
+ obs: jnp.ndarray
+ info: jnp.ndarray
+
+
+def make_train(config):
+ config["NUM_UPDATES"] = (
+ config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
+ )
+ config["MINIBATCH_SIZE"] = (
+ config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
+ )
+
+ # Create environment
+ env = make_craftax_env_from_name(
+ config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
+ )
+ env_params = env.default_params
+
+ # Wrap with some extra logging
+ env = LogWrapper(env)
+
+ # Wrap with a batcher, maybe using optimistic resets
+ if config["USE_OPTIMISTIC_RESETS"]:
+ env = OptimisticResetVecEnvWrapper(
+ env,
+ num_envs=config["NUM_ENVS"],
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
+ )
+ else:
+ env = AutoResetEnvWrapper(env)
+ env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
+
+ def linear_schedule(count):
+ frac = (
+ 1.0
+ - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
+ / config["NUM_UPDATES"]
+ )
+ return config["LR"] * frac
+
+ def train(rng):
+ # INIT NETWORK
+ network = ActorCriticRNN(env.action_space(env_params).n, config=config)
+ rng, _rng = jax.random.split(rng)
+ init_x = (
+ jnp.zeros(
+ (1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
+ ),
+ jnp.zeros((1, config["NUM_ENVS"])),
+ )
+ init_hstate = ScannedRNN.initialize_carry(
+ config["NUM_ENVS"], config["LAYER_SIZE"]
+ )
+ network_params = network.init(_rng, init_hstate, init_x)
+ if config["ANNEAL_LR"]:
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
+ )
+ else:
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["LR"], eps=1e-5),
+ )
+ train_state = TrainState.create(
+ apply_fn=network.apply,
+ params=network_params,
+ tx=tx,
+ )
+
+ # INIT ENV
+ rng, _rng = jax.random.split(rng)
+ obsv, env_state = env.reset(_rng, env_params)
+ init_hstate = ScannedRNN.initialize_carry(
+ config["NUM_ENVS"], config["LAYER_SIZE"]
+ )
+
+ # TRAIN LOOP
+ def _update_step(runner_state, unused):
+ # COLLECT TRAJECTORIES
+ def _env_step(runner_state, unused):
+ (
+ train_state,
+ env_state,
+ last_obs,
+ last_done,
+ hstate,
+ rng,
+ update_step,
+ ) = runner_state
+ rng, _rng = jax.random.split(rng)
+
+ # SELECT ACTION
+ ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
+ hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
+ action = pi.sample(seed=_rng)
+ log_prob = pi.log_prob(action)
+ value, action, log_prob = (
+ value.squeeze(0),
+ action.squeeze(0),
+ log_prob.squeeze(0),
+ )
+
+ # STEP ENV
+ rng, _rng = jax.random.split(rng)
+ obsv, env_state, reward, done, info = env.step(
+ _rng, env_state, action, env_params
+ )
+ transition = Transition(
+ last_done, action, value, reward, log_prob, last_obs, info
+ )
+ runner_state = (
+ train_state,
+ env_state,
+ obsv,
+ done,
+ hstate,
+ rng,
+ update_step,
+ )
+ return runner_state, transition
+
+ initial_hstate = runner_state[-3]
+ runner_state, traj_batch = jax.lax.scan(
+ _env_step, runner_state, None, config["NUM_STEPS"]
+ )
+
+ # CALCULATE ADVANTAGE
+ (
+ train_state,
+ env_state,
+ last_obs,
+ last_done,
+ hstate,
+ rng,
+ update_step,
+ ) = runner_state
+ ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
+ _, _, last_val = network.apply(train_state.params, hstate, ac_in)
+ last_val = last_val.squeeze(0)
+
+ def _calculate_gae(traj_batch, last_val, last_done):
+ def _get_advantages(carry, transition):
+ gae, next_value, next_done = carry
+ done, value, reward = (
+ transition.done,
+ transition.value,
+ transition.reward,
+ )
+ delta = (
+ reward + config["GAMMA"] * next_value * (1 - next_done) - value
+ )
+ gae = (
+ delta
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae
+ )
+ return (gae, value, done), gae
+
+ _, advantages = jax.lax.scan(
+ _get_advantages,
+ (jnp.zeros_like(last_val), last_val, last_done),
+ traj_batch,
+ reverse=True,
+ unroll=16,
+ )
+ return advantages, advantages + traj_batch.value
+
+ advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
+
+ # UPDATE NETWORK
+ def _update_epoch(update_state, unused):
+ def _update_minbatch(train_state, batch_info):
+ init_hstate, traj_batch, advantages, targets = batch_info
+
+ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
+ # RERUN NETWORK
+ _, pi, value = network.apply(
+ params, init_hstate[0], (traj_batch.obs, traj_batch.done)
+ )
+ log_prob = pi.log_prob(traj_batch.action)
+
+ # CALCULATE VALUE LOSS
+ value_pred_clipped = traj_batch.value + (
+ value - traj_batch.value
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
+ value_losses = jnp.square(value - targets)
+ value_losses_clipped = jnp.square(value_pred_clipped - targets)
+ value_loss = (
+ 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
+ )
+
+ # CALCULATE ACTOR LOSS
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
+ loss_actor1 = ratio * gae
+ loss_actor2 = (
+ jnp.clip(
+ ratio,
+ 1.0 - config["CLIP_EPS"],
+ 1.0 + config["CLIP_EPS"],
+ )
+ * gae
+ )
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
+ loss_actor = loss_actor.mean()
+ entropy = pi.entropy().mean()
+
+ total_loss = (
+ loss_actor
+ + config["VF_COEF"] * value_loss
+ - config["ENT_COEF"] * entropy
+ )
+ return total_loss, (value_loss, loss_actor, entropy)
+
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+ total_loss, grads = grad_fn(
+ train_state.params, init_hstate, traj_batch, advantages, targets
+ )
+ train_state = train_state.apply_gradients(grads=grads)
+ return train_state, total_loss
+
+ (
+ train_state,
+ init_hstate,
+ traj_batch,
+ advantages,
+ targets,
+ rng,
+ ) = update_state
+
+ rng, _rng = jax.random.split(rng)
+ permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
+ batch = (init_hstate, traj_batch, advantages, targets)
+
+ shuffled_batch = jax.tree.map(
+ lambda x: jnp.take(x, permutation, axis=1), batch
+ )
+
+ minibatches = jax.tree.map(
+ lambda x: jnp.swapaxes(
+ jnp.reshape(
+ x,
+ [x.shape[0], config["NUM_MINIBATCHES"], -1]
+ + list(x.shape[2:]),
+ ),
+ 1,
+ 0,
+ ),
+ shuffled_batch,
+ )
+
+ train_state, total_loss = jax.lax.scan(
+ _update_minbatch, train_state, minibatches
+ )
+ update_state = (
+ train_state,
+ init_hstate,
+ traj_batch,
+ advantages,
+ targets,
+ rng,
+ )
+ return update_state, total_loss
+
+ init_hstate = initial_hstate[None, :] # TBH
+ update_state = (
+ train_state,
+ init_hstate,
+ traj_batch,
+ advantages,
+ targets,
+ rng,
+ )
+ update_state, loss_info = jax.lax.scan(
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
+ )
+ train_state = update_state[0]
+ metric = jax.tree.map(
+ lambda x: (x * traj_batch.info["returned_episode"]).sum()
+ / traj_batch.info["returned_episode"].sum(),
+ traj_batch.info,
+ )
+ rng = update_state[-1]
+ if config["DEBUG"] and config["USE_WANDB"]:
+
+ def callback(metric, update_step):
+ to_log = create_log_dict(metric, config)
+ batch_log(update_step, to_log, config)
+
+ jax.debug.callback(callback, metric, update_step)
+
+ runner_state = (
+ train_state,
+ env_state,
+ last_obs,
+ last_done,
+ hstate,
+ rng,
+ update_step + 1,
+ )
+ return runner_state, metric
+
+ rng, _rng = jax.random.split(rng)
+ runner_state = (
+ train_state,
+ env_state,
+ obsv,
+ jnp.zeros((config["NUM_ENVS"]), dtype=bool),
+ init_hstate,
+ _rng,
+ 0,
+ )
+ runner_state, metric = jax.lax.scan(
+ _update_step, runner_state, None, config["NUM_UPDATES"]
+ )
+ return {"runner_state": runner_state, "metric": metric}
+
+ return train
+
+
+def run_ppo(config):
+ config = {k.upper(): v for k, v in config.__dict__.items()}
+
+ if config["USE_WANDB"]:
+ wandb.init(
+ project=config["WANDB_PROJECT"],
+ entity=config["WANDB_ENTITY"],
+ config=config,
+ name=config["ENV_NAME"]
+ + "-PPO_RNN-"
+ + str(int(config["TOTAL_TIMESTEPS"] // 1e6))
+ + "M",
+ )
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
+
+ train_jit = jax.jit(make_train(config))
+ train_vmap = jax.vmap(train_jit)
+
+ t0 = time.time()
+ out = train_vmap(rngs)
+ t1 = time.time()
+ print("Time to run experiment", t1 - t0)
+ print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
+
+ if config["USE_WANDB"]:
+
+ def _save_network(rs_index, dir_name):
+ train_states = out["runner_state"][rs_index]
+ train_state = jax.tree.map(lambda x: x[0], train_states)
+
+ path = os.path.join(wandb.run.dir, dir_name)
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
+
+ with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
+ checkpoint_manager.save(
+ int(config["TOTAL_TIMESTEPS"]),
+ args=ocp.args.StandardSave(train_state)
+ )
+
+ print(f"saved runner state to {path}")
+
+ if config["SAVE_POLICY"]:
+ _save_network(0, "policies")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
+ parser.add_argument(
+ "--num_envs",
+ type=int,
+ default=1024,
+ )
+ parser.add_argument("--total_timesteps", type=lambda x: int(float(x)), default=1e9)
+ parser.add_argument("--lr", type=float, default=2e-4)
+ parser.add_argument("--num_steps", type=int, default=64)
+ parser.add_argument("--update_epochs", type=int, default=4)
+ parser.add_argument("--num_minibatches", type=int, default=8)
+ parser.add_argument("--gamma", type=float, default=0.99)
+ parser.add_argument("--gae_lambda", type=float, default=0.8)
+ parser.add_argument("--clip_eps", type=float, default=0.2)
+ parser.add_argument("--ent_coef", type=float, default=0.01)
+ parser.add_argument("--vf_coef", type=float, default=0.5)
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
+ parser.add_argument("--activation", type=str, default="tanh")
+ parser.add_argument(
+ "--anneal_lr", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--seed", type=int, default=np.random.randint(2**31))
+ parser.add_argument(
+ "--use_wandb", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument(
+ "--save_policy", action=argparse.BooleanOptionalAction, default=False
+ )
+ parser.add_argument("--num_repeats", type=int, default=1)
+ parser.add_argument("--layer_size", type=int, default=512)
+ parser.add_argument("--wandb_project", type=str)
+ parser.add_argument("--wandb_entity", type=str)
+ parser.add_argument(
+ "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
+ )
+ parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
+
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
+ if rest_args:
+ raise ValueError(f"Unknown args {rest_args}")
+
+ if args.seed is None:
+ args.seed = np.random.randint(2**31)
+
+ if args.jit:
+ run_ppo(args)
+ else:
+ with jax.disable_jit():
+ run_ppo(args)
diff --git a/Craftax_Baselines/requirements.txt b/Craftax_Baselines/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ed907da69758eef3abbe6c9b9fcca670e9794633
--- /dev/null
+++ b/Craftax_Baselines/requirements.txt
@@ -0,0 +1,16 @@
+jax[cuda12_pip]
+distrax
+optax
+flax
+numpy
+black
+pre-commit
+argparse
+wandb
+orbax-checkpoint==0.5.0
+pygame
+gymnax
+chex
+matplotlib
+imageio
+craftax
diff --git a/Craftax_Baselines/run_docker.sh b/Craftax_Baselines/run_docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f22e47879f9809ecbfa3f055e7ff4726ea06c613
--- /dev/null
+++ b/Craftax_Baselines/run_docker.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+WANDB_API_KEY=$(cat ./wandb_key)
+# git pull
+
+script_and_args="${@:2}"
+if [ $1 == "all" ]; then
+ gpus="0 1 2 3 4 5 6 7"
+else
+ gpus=$1
+fi
+
+for gpu in $gpus; do
+ echo "Launching container craftax_$gpu on GPU $gpu"
+ docker run \
+ --gpus device=$gpu \
+ -e WANDB_API_KEY=$WANDB_API_KEY \
+ -v $(pwd):/home/duser/Craftax \
+ --name craftax_$gpu \
+ --user $(id -u) \
+ --rm \
+ -d \
+ -t craftax_baselines \
+ /bin/bash -c "$script_and_args"
+done
diff --git a/Craftax_Baselines/wrappers.py b/Craftax_Baselines/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d601e184f2abd2b6b4c0a1ddc6839460d7db63
--- /dev/null
+++ b/Craftax_Baselines/wrappers.py
@@ -0,0 +1,200 @@
+import jax
+import jax.numpy as jnp
+import chex
+import numpy as np
+from flax import struct
+from functools import partial
+from typing import Optional, Tuple, Union, Any
+
+
+class GymnaxWrapper(object):
+ """Base class for Gymnax wrappers."""
+
+ def __init__(self, env):
+ self._env = env
+
+ # provide proxy access to regular attributes of wrapped object
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+
+class BatchEnvWrapper(GymnaxWrapper):
+ """Batches reset and step functions"""
+
+ def __init__(self, env, num_envs: int):
+ super().__init__(env)
+
+ self.num_envs = num_envs
+
+ self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
+ self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(self, rng, params=None):
+ rng, _rng = jax.random.split(rng)
+ rngs = jax.random.split(_rng, self.num_envs)
+ obs, env_state = self.reset_fn(rngs, params)
+ return obs, env_state
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(self, rng, state, action, params=None):
+ rng, _rng = jax.random.split(rng)
+ rngs = jax.random.split(_rng, self.num_envs)
+ obs, state, reward, done, info = self.step_fn(rngs, state, action, params)
+
+ return obs, state, reward, done, info
+
+
+class AutoResetEnvWrapper(GymnaxWrapper):
+ """Provides standard auto-reset functionality, providing the same behaviour as Gymnax-default."""
+
+ def __init__(self, env):
+ super().__init__(env)
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(self, key, params=None):
+ return self._env.reset(key, params)
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(self, rng, state, action, params=None):
+
+ rng, _rng = jax.random.split(rng)
+ obs_st, state_st, reward, done, info = self._env.step(
+ _rng, state, action, params
+ )
+
+ rng, _rng = jax.random.split(rng)
+ obs_re, state_re = self._env.reset(_rng, params)
+
+ # Auto-reset environment based on termination
+ def auto_reset(done, state_re, state_st, obs_re, obs_st):
+ state = jax.tree.map(
+ lambda x, y: jax.lax.select(done, x, y), state_re, state_st
+ )
+ obs = jax.lax.select(done, obs_re, obs_st)
+
+ return obs, state
+
+ obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)
+
+ return obs, state, reward, done, info
+
+
+class OptimisticResetVecEnvWrapper(GymnaxWrapper):
+ """
+ Provides efficient 'optimistic' resets.
+ The wrapper also necessarily handles the batching of environment steps and resetting.
+ reset_ratio: the number of environment workers per environment reset. Higher means more efficient but a higher
+ chance of duplicate resets.
+ """
+
+ def __init__(self, env, num_envs: int, reset_ratio: int):
+ super().__init__(env)
+
+ self.num_envs = num_envs
+ self.reset_ratio = reset_ratio
+ assert (
+ num_envs % reset_ratio == 0
+ ), "Reset ratio must perfectly divide num envs."
+ self.num_resets = self.num_envs // reset_ratio
+
+ self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
+ self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(self, rng, params=None):
+ rng, _rng = jax.random.split(rng)
+ rngs = jax.random.split(_rng, self.num_envs)
+ obs, env_state = self.reset_fn(rngs, params)
+ return obs, env_state
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(self, rng, state, action, params=None):
+
+ rng, _rng = jax.random.split(rng)
+ rngs = jax.random.split(_rng, self.num_envs)
+ obs_st, state_st, reward, done, info = self.step_fn(rngs, state, action, params)
+
+ rng, _rng = jax.random.split(rng)
+ rngs = jax.random.split(_rng, self.num_resets)
+ obs_re, state_re = self.reset_fn(rngs, params)
+
+ rng, _rng = jax.random.split(rng)
+ reset_indexes = jnp.arange(self.num_resets).repeat(self.reset_ratio)
+
+ being_reset = jax.random.choice(
+ _rng,
+ jnp.arange(self.num_envs),
+ shape=(self.num_resets,),
+ p=done,
+ replace=False,
+ )
+ reset_indexes = reset_indexes.at[being_reset].set(jnp.arange(self.num_resets))
+
+ obs_re = obs_re[reset_indexes]
+ state_re = jax.tree.map(lambda x: x[reset_indexes], state_re)
+
+ # Auto-reset environment based on termination
+ def auto_reset(done, state_re, state_st, obs_re, obs_st):
+ state = jax.tree.map(
+ lambda x, y: jax.lax.select(done, x, y), state_re, state_st
+ )
+ obs = jax.lax.select(done, obs_re, obs_st)
+
+ return state, obs
+
+ state, obs = jax.vmap(auto_reset)(done, state_re, state_st, obs_re, obs_st)
+
+ return obs, state, reward, done, info
+
+
+@struct.dataclass
+class LogEnvState:
+ env_state: Any
+ episode_returns: float
+ episode_lengths: int
+ returned_episode_returns: float
+ returned_episode_lengths: int
+ timestep: int
+
+
+class LogWrapper(GymnaxWrapper):
+ """Log the episode returns and lengths."""
+
+ def __init__(self, env):
+ super().__init__(env)
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(self, key: chex.PRNGKey, params=None):
+ obs, env_state = self._env.reset(key, params)
+ state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
+ return obs, state
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(
+ self,
+ key: chex.PRNGKey,
+ state,
+ action: Union[int, float],
+ params=None,
+ ):
+ obs, env_state, reward, done, info = self._env.step(
+ key, state.env_state, action, params
+ )
+ new_episode_return = state.episode_returns + reward
+ new_episode_length = state.episode_lengths + 1
+ state = LogEnvState(
+ env_state=env_state,
+ episode_returns=new_episode_return * (1 - done),
+ episode_lengths=new_episode_length * (1 - done),
+ returned_episode_returns=state.returned_episode_returns * (1 - done)
+ + new_episode_return * done,
+ returned_episode_lengths=state.returned_episode_lengths * (1 - done)
+ + new_episode_length * done,
+ timestep=state.timestep + 1,
+ )
+ info["returned_episode_returns"] = state.returned_episode_returns
+ info["returned_episode_lengths"] = state.returned_episode_lengths
+ info["timestep"] = state.timestep
+ info["returned_episode"] = done
+ return obs, state, reward, done, info
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b43bf2c29f63c5bd4fff8ace05c72e9ace33f8e5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,547 @@
+# ReMDM Planner — Discrete Diffusion Planning on Craftax
+
+A JAX implementation of **ReMDM** (Remasking Discrete Diffusion Model) for action-sequence planning in the [Craftax](https://github.com/MichaelTMatthews/Craftax) environment. A bidirectional transformer learns to generate action plans by iteratively denoising masked token sequences, conditioned on the current environment observation.
+
+---
+
+## Description
+
+The planner starts from a fully-masked action sequence and iteratively unmasks tokens over `T` denoising steps, producing a `plan_horizon`-length plan. The ReMDM framework extends standard Masked Discrete Language Modelling (MDLM) with remasking strategies that allow committed tokens to be re-predicted, improving plan coherence.
+
+Two independent training pipelines are available — **Offline BC** and **Online DAgger** — both supervised by a pre-trained PPO expert but otherwise separate. Neither depends on the other; the paper compares them head-to-head.
+
+```
+[Shared] Train PPO agent Craftax_Baselines/ppo_rnn.py | ppo_rnd.py
+ |
+ v checkpoint
+ ┌───────┴────────┐
+ │ │
+ [Offline BC] [Online DAgger]
+ main.py main.py
+ --mode offline --mode online
+ (train on live (train from scratch;
+ PPO rollouts) mixed policy + expert
+ │ labels into replay buffer)
+ v v
+ checkpoint checkpoint
+ │ │
+ └───────┬────────┘
+ v
+[Evaluate] main.py --mode inference --checkpoint_path ...
+
+Optional: an offline BC checkpoint can warm-start DAgger
+via --offline_checkpoint_path (not used in the paper).
+
+ [Offline BC] ──checkpoint──> [Online DAgger]
+```
+
+**Optional utility modes:**
+```
+[Collect] Save PPO rollouts to disk main.py --mode collect
+[Smoke test] Quick end-to-end check main.py --mode smoke
+```
+
+---
+
+## Installation
+
+### Prerequisites (system-level)
+
+`uv` manages Python packages only. The following must be installed at the OS level before
+running on a GPU node — they are **not** in `pyproject.toml`:
+
+- **CUDA 13** driver and toolkit (`libcuda.so`, `libcudnn`)
+
+On HPC clusters these are typically loaded via `module load cuda/13.x`.
+
+### 1. Create the virtual environment
+
+```bash
+# CPU-only (local development / macOS)
+uv sync
+
+# NVIDIA CUDA 13 (GPU node — Linux only)
+uv sync --extra cuda
+
+# Activate
+source .venv/bin/activate
+```
+
+`uv sync` reads `pyproject.toml`, resolves a fully-reproducible lockfile (`uv.lock`),
+and installs into `.venv/`. Commit `uv.lock` to pin the exact dependency graph.
+
+### 2. Initialise the submodule
+
+```bash
+git submodule update --init --recursive
+```
+
+---
+
+## Dependencies
+
+| Package | Version | Role |
+|---------|---------|------|
+| `jax` | >=0.9.2 | JIT compilation and functional arrays |
+| `flax` | >=0.12.6 | Neural network definitions |
+| `optax` | >=0.2.8 | Adam optimiser and gradient clipping |
+| `craftax` | >=1.5.0 | Procedurally-generated Minecraft-like environment |
+| `chex` | >=0.1.91 | JAX testing and assertion utilities |
+| `distrax` | >=0.1.7 | Probability distributions |
+| `orbax` | >=0.1.9 | Model checkpointing |
+| `wandb` | >=0.25.1 | Experiment logging |
+| `numpy` | >=2.4.4 | Array operations |
+| `matplotlib` | >=3.10.8 | Plotting |
+| `polars` | >=1.39.3 | DataFrame analysis |
+| `orjson` | >=3.11.8 | Fast JSON serialisation |
+| `pyyaml` | >=6.0.3 | Config file parsing |
+
+Full specification in `pyproject.toml`. Exact transitive pins are in `uv.lock`.
+
+---
+
+## Usage
+
+All modes share the same entry point. Defaults are loaded from `configs/defaults.yaml`; any value can be overridden on the command line.
+
+```bash
+python main.py --mode [--config PATH] [OVERRIDES...]
+```
+
+Pass `--no-jit` to disable JIT compilation (useful for debugging):
+
+```bash
+python main.py --mode offline --no-jit --num_envs 4
+```
+
+### Stage 1 — Train a PPO agent
+
+PPO training is handled by the `Craftax_Baselines` submodule and produces the checkpoint consumed by all downstream stages.
+
+```bash
+cd Craftax_Baselines
+
+# PPO with GRU hidden state (recommended)
+python ppo_rnn.py \
+ --env_name Craftax-Classic-Symbolic-v1 \
+ --total_timesteps 500000000 \
+ --save_policy --use_wandb
+
+# PPO with Random Network Distillation
+python ppo_rnd.py \
+ --env_name Craftax-Classic-Symbolic-v1 \
+ --total_timesteps 500000000 \
+ --save_policy --use_wandb
+
+cd ..
+```
+
+### Stage 2a — Collect trajectories to disk
+
+Roll out the PPO checkpoint and save `(obs, actions, rewards, dones)` as a `.npz` file for reuse across multiple diffusion training runs.
+
+```bash
+python main.py --mode collect \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --offline_data_path data/trajectories.npz \
+ --collect_num_steps 1000000 \
+ --collect_num_envs 128
+```
+
+The file stores arrays shaped `[num_envs, num_iters, ...]`, preserving per-environment contiguity so episode boundaries are respected during window sampling.
+
+### Stage 2b — Train offline from live PPO rollouts
+
+Roll out the PPO agent live at each update step and train the diffusion model on the collected windows. Windows that cross episode boundaries are masked out; windows with higher cumulative reward receive proportionally larger gradient contributions (clipped to `[0.1, return_weight_cap]`).
+
+```bash
+python main.py --mode offline \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --offline_total_timesteps 100000000 \
+ --save_policy
+```
+
+### Online DAgger Training
+
+The diffusion model is trained **from scratch** via DAgger (Dataset Aggregation). At each iteration a mixed policy blends the PPO expert and the diffusion learner (controlled by an exponentially decaying `beta`). The mixed policy rolls out trajectories; the expert labels every visited state with the action it would take. These `(obs, expert_plan)` pairs are appended to a growing circular replay buffer, and the diffusion model is trained on the full buffer with the standard MDLM ELBO loss (pure behavioural cloning — no reward weighting).
+
+```bash
+# From scratch (requires PPO expert checkpoint)
+python main.py --mode online \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --online_num_updates 1000 \
+ --save_policy
+
+# Optional: warm-start from a pre-trained offline checkpoint
+# (not used in the paper — both methods are compared independently)
+python main.py --mode online \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --offline_checkpoint_path /path/to/offline_checkpoint \
+ --online_num_updates 1000 \
+ --save_policy
+```
+
+When `save_policy=true`, online training uploads **two** W&B artifacts: `{env_name}-policy` (final weights) and `{env_name}-policy-best` (weights from the validation iteration with the highest return). Either artifact can be consumed downstream by `--checkpoint_path wandb:…`.
+
+### Stage 4 — Evaluate
+
+```bash
+python main.py --mode inference \
+ --checkpoint_path /path/to/checkpoint \
+ --eval_steps 10000 \
+ --eval_num_envs 32
+```
+
+Prints mean episode return, completed episodes, steps per second, and per-achievement unlock counts. Uses historical inpainting: the first `hist_len` plan positions are locked to observed history.
+
+### Loading checkpoints from W&B artifacts
+
+Any checkpoint path argument (`--checkpoint_path`, `--offline_checkpoint_path`, `--ppo_checkpoint_path`) accepts a W&B artifact reference prefixed with `wandb:`. The artifact is downloaded automatically before training or evaluation begins.
+
+```bash
+# Fully qualified: entity/project/artifact_name:version_or_alias
+python main.py --mode inference \
+ --checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:latest
+
+# Online fine-tuning from a W&B offline checkpoint
+python main.py --mode online \
+ --offline_checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:v3
+
+# PPO checkpoint from W&B
+python main.py --mode offline \
+ --ppo_checkpoint_path wandb:my-team/ppo-craftax/ppo-rnn-policy:best
+```
+
+Control the download location with `--wandb_download_dir` (defaults to `./artifacts/`).
+
+### Resuming a Training Run
+
+A completed training checkpoint can be used as the starting point for a new run that continues where the previous one left off. This is useful when extending the training budget or when a preempted job needs to be restarted.
+
+**Offline resume:**
+
+```bash
+# Auto-detect step and wandb run ID from checkpoint metadata
+python main.py --mode offline \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --resume_checkpoint_path /path/to/completed_offline_checkpoint \
+ --offline_total_timesteps 200000000 \
+ --save_policy
+
+# Explicit step and wandb run ID override
+python main.py --mode offline \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --resume_checkpoint_path /path/to/completed_offline_checkpoint \
+ --resume_step 1525 \
+ --resume_wandb_run_id abc123xyz \
+ --offline_total_timesteps 200000000 \
+ --save_policy
+
+# Resume from a W&B artifact
+python main.py --mode offline \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --resume_checkpoint_path wandb:my-team/remdm-craftax/policy:latest \
+ --offline_total_timesteps 200000000 \
+ --save_policy
+```
+
+**Online resume:**
+
+```bash
+python main.py --mode online \
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
+ --resume_checkpoint_path /path/to/completed_online_checkpoint \
+ --online_num_updates 2000 \
+ --save_policy
+```
+
+**Notes:**
+- The DAgger replay buffer is **not** persisted across resumes. It starts empty and refills within the first few iterations.
+- JIT compilation is fully preserved. Resume only affects initialisation outside `jax.jit` (loading checkpoint, setting the optimizer step counter, adjusting scan length).
+- The cosine LR schedule is constructed for the full `num_updates` range. The optimizer step counter is set to the resume offset so the learning rate picks up exactly where the previous run stopped.
+- When `resume_checkpoint_path` points to a checkpoint with a metadata sidecar, `resume_step` and `resume_wandb_run_id` are auto-detected. Explicit CLI flags override the metadata values.
+- Checkpoints without a metadata sidecar (created before this feature) still load; provide `--resume_step` explicitly.
+
+
+---
+
+## Configuration
+
+All hyperparameters are in `configs/defaults.yaml`. Override any value on the command line:
+
+```bash
+python main.py --mode offline --lr 1e-4 --plan_horizon 64 --num_minibatches 16
+```
+
+Point to a custom config file:
+
+```bash
+python main.py --mode online --config configs/my_experiment.yaml
+```
+
+Preset configs for larger runs are provided in `configs/`:
+
+| File | Purpose |
+|------|---------|
+| `configs/defaults.yaml` | Base defaults for all modes |
+| `configs/classic_exp_a_beta_fix.yaml` | Craftax Classic DAgger — beta decay fix only (isolates data quality) |
+| `configs/classic_exp_b_beta_big_model.yaml` | Craftax Classic DAgger — beta fix + 3.5× larger transformer |
+| `configs/classic_exp_c_full_recipe.yaml` | Craftax Classic DAgger — beta + big model + training dynamics |
+| `configs/craftax_exp_a_beta_fix.yaml` | Full Craftax DAgger — beta decay fix only |
+| `configs/craftax_exp_b_beta_big_model.yaml` | Full Craftax DAgger — beta fix + larger transformer |
+| `configs/craftax_exp_c_full_recipe.yaml` | Full Craftax DAgger — full recipe |
+| `configs/final_classic_ucl.yaml` | Final Craftax Classic DAgger — UCL 3090 Ti, seed 42 (produces the Classic checkpoint consumed by the ablation suite) |
+| `configs/final_classic_qmul.yaml` | Env-frame-matched second seed of `final_classic_ucl.yaml` (QMUL H200, seed 43) |
+| `configs/final_craftax_ucl.yaml` | Final Full Craftax DAgger — UCL 4090, seed 42 (produces the Full Craftax checkpoint consumed by the ablation suite) |
+| `configs/final_craftax_qmul.yaml` | Env-frame-matched second seed of `final_craftax_ucl.yaml` (QMUL H200, seed 43) |
+
+RL fine-tuning ablation hyperparameters live under `experiments/rl_finetuning/configs/` and are loaded by `run_ablations.py`, not by `main.py`. See `experiments/README.md`.
+
+The `final_*_qmul.yaml` presets differ from their UCL counterparts only in `num_envs` (smaller partition) and `seed`. All fairness-critical hyperparameters are denominated in env frames or update cycles and automatically rescaled by `resolve_scaled_hyperparams()` at load time, so no manual derivation is needed when moving between hardware tiers.
+
+### Key hyperparameters
+
+**Environment**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `env_name` | `Craftax-Classic-Symbolic-v1` | Craftax environment ID. Use `Craftax-Symbolic-v1` for Full Craftax. |
+| `use_optimistic_resets` | `false` | Use `OptimisticResetVecEnvWrapper` instead of `AutoResetEnvWrapper` |
+| `optimistic_reset_ratio` | 16 | Fraction of envs reset per step when optimistic resets are enabled |
+
+**Diffusion model**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `plan_horizon` | 32 | Action plan length H |
+| `diffusion_steps` | 15 | Denoising steps T at inference |
+| `diffusion_schedule` | `cosine` | Noise schedule: `cosine` or `linear` |
+| `remask_strategy` | `rescale` | Remasking strategy: `rescale`, `cap`, or `conf` |
+| `train_sigma` | 0.0 | Per-token remasking correction during training (0 = standard MDLM) |
+| `label_smoothing` | 0.0 | Cross-entropy label smoothing epsilon (0 = exact ELBO) |
+| `eta` | 0.5 | Remasking strength |
+| `use_loop` | `true` | Three-phase loop remasking (Algorithm 3) |
+| `t_on` / `t_off` | 0.7 / 0.3 | Time window boundaries for loop remasking |
+| `temperature` | 0.5 | Softmax temperature for token sampling |
+| `top_p` | 0.95 | Nucleus sampling threshold |
+
+**Transformer architecture**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `d_model` | 256 | Hidden dimension |
+| `n_heads` | 4 | Attention heads |
+| `n_layers` | 4 | Transformer blocks |
+| `d_ff` | 512 | FFN inner dimension |
+| `obs_encoder_layers` | 2 | MLP layers in the observation encoder |
+| `obs_encoder_width` | 512 | Observation encoder hidden width |
+| `dropout_rate` | 0.1 | Dropout rate (disabled at inference) |
+
+**Offline training**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `offline_total_timesteps` | 1e8 | **PRIMARY** env-frame budget for live-PPO data collection. Derives `num_updates` as `offline_total_timesteps // (num_envs * num_steps)`, making the run hardware-portable across `num_envs` changes. |
+| `offline_num_updates` | `null` | **LEGACY** outer update count; used only when `offline_total_timesteps` is unset. |
+| `num_envs` | 1024 | Parallel environments |
+| `num_steps` | 64 | Environment steps collected per update |
+| `num_minibatches` | 8 | Gradient minibatches per epoch |
+| `update_epochs` | 4 | SGD epochs per update step |
+| `num_repeats` | 1 | Independent training seeds (vmapped) |
+| `lr` | 3e-4 | Adam learning rate (cosine-decayed to 10% over all gradient steps) |
+| `lr_warmup_frames` | `null` | **PRIMARY** env-frame warm-up budget. Derives `lr_warmup_steps` as `lr_warmup_frames // (num_envs * num_steps)`. |
+| `lr_warmup_steps` | 0 | **LEGACY** linear warm-up steps before cosine decay (used when `lr_warmup_frames` is unset; 0 = disabled). |
+| `max_grad_norm` | 1.0 | Global gradient clipping norm |
+| `return_weight_cap` | 5.0 | Clip ceiling for per-window return weights (lower clip is fixed at 0.1) |
+| `collect_temperature` | 1.0 | Softmax temperature on PPO logits during live data collection |
+| `val_interval_frames` | `null` | **PRIMARY** env-frames between validation rollouts. Overrides `val_interval` via `val_interval = val_interval_frames // (num_envs * num_steps)`. |
+| `val_interval` | 50 | **LEGACY** validation frequency in update steps (used when `val_interval_frames` is unset). |
+| `val_diffusion_steps` | 50 | Denoising steps used during validation rollouts |
+| `val_replan_every` | 4 | Environment steps executed per diffusion plan during validation |
+| `val_steps` | 128 | Total environment steps per validation rollout |
+
+**Online DAgger training**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `online_total_timesteps` | `null` | **PRIMARY** env-frame budget for online DAgger (hardware-portable). Derives `num_updates` as `online_total_timesteps // (num_envs * num_steps)`. |
+| `online_num_updates` | 1000 | **LEGACY** outer DAgger iterations (used when `online_total_timesteps` is unset). |
+| `dagger_beta_init` | 1.0 | Initial expert mixing probability `beta_1` (1.0 = pure expert on the first iteration). |
+| `dagger_beta_final` | `null` | **PRIMARY** target mixing ratio at the end of training. Overrides `dagger_beta_decay` via `decay = (beta_final / beta_init) ** (1 / num_updates)`. |
+| `dagger_beta_decay` | 0.95 | **LEGACY** per-update decay: `beta_i = beta_init * decay^i` (used when `dagger_beta_final` is unset). |
+| `dagger_buffer_cycles` | `null` | **PRIMARY** buffer capacity denominated in update cycles of history (1 cycle = `num_envs * num_steps` frames). Overrides `dagger_buffer_max` via `buffer_max = cycles * (num_envs * num_steps)`. |
+| `dagger_buffer_max` | 100000 | **LEGACY** max samples in the DAgger replay buffer (circular eviction when full). |
+| `dagger_train_passes` | `null` | Passes per update over the aggregated buffer. `null` = 1 pass (matches offline BC per-update gradient work exactly for fair compute comparison). Raise to >1 to trade BC fairness for wider per-update buffer coverage. |
+| `dagger_expert_deterministic` | `true` | If `true`, the PPO expert takes the argmax action (fixed `s → a*` map); if `false`, it samples categorically. Deterministic removes label noise from the aggregated dataset. |
+
+**Data collection**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `collect_num_steps` | 10000000 | Total environment steps to collect |
+| `collect_num_envs` | 128 | Parallel environments during collection |
+| `ppo_model_type` | `ppo_rnn` | PPO architecture: `ppo`, `ppo_rnn`, or `ppo_rnd` |
+| `layer_size` | 512 | PPO actor-critic hidden layer width |
+
+**Inference**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `eval_steps` | 10000 | Environment steps for evaluation |
+| `eval_num_envs` | 32 | Parallel agents during evaluation (independent of `num_envs`) |
+| `diffusion_steps_eval` | 10 | Denoising steps T used at evaluation time |
+
+**Checkpointing**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `save_policy` | `true` | Save final checkpoint at end of training and upload it as a W&B artifact |
+
+**Resume**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `resume_checkpoint_path` | `null` | Path to a completed checkpoint to resume from (accepts `wandb:` refs) |
+| `resume_wandb_run_id` | `null` | W&B run ID to resume logging into (auto-read from checkpoint metadata) |
+| `resume_step` | `null` | Update step the checkpoint was saved at (auto-read from checkpoint metadata) |
+
+**Logging**
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `use_wandb` | `true` | Enable Weights & Biases logging |
+| `wandb_project` | `remdm-craftax` | W&B project name |
+| `wandb_entity` | `"mathis-weil-university-college-london-ucl-"` | W&B entity (team or username) |
+| `wandb_download_dir` | `null` | Download directory for W&B artifacts; null = `./artifacts/` |
+| `seed` | `null` | RNG seed (random if null) |
+
+---
+
+## Remasking Strategies
+
+Controlled by `--remask_strategy`. All strategies operate on top of the three-phase loop controlled by `--use_loop`, `--t_on`, and `--t_off`.
+
+| Strategy | Formula | Description |
+|----------|---------|-------------|
+| `rescale` | `sigma = eta * sigma_max` | Scales maximum remasking probability proportionally |
+| `cap` | `sigma = min(eta, sigma_max)` | Caps remasking at a fixed rate |
+| `conf` | `sigma = eta * sigma_max * (1 - confidence)` | High-confidence tokens are remasked less |
+
+---
+
+## Environment Wrappers
+
+**From `Craftax_Baselines/wrappers.py`** (submodule):
+
+| Wrapper | Purpose |
+|---------|---------|
+| `LogWrapper` | Tracks episode returns and lengths; adds stats to the info dict |
+| `AutoResetEnvWrapper` | Automatically resets episodes on `done` |
+| `BatchEnvWrapper` | Vmaps `reset` and `step` over `num_envs` environments |
+| `OptimisticResetVecEnvWrapper` | Batched resets with reduced overhead; enable via `--use_optimistic_resets` |
+
+**From `src/envs/wrappers.py`**:
+
+| Wrapper | Purpose |
+|---------|---------|
+| `SequenceHistoryWrapper` | Maintains a sliding window of past observations and actions in the env state |
+| `DiscreteTokenizationWrapper` | Quantizes continuous observations into discrete token indices |
+| `PlannerWrapper` | Manages the plan/replan cycle for the diffusion planner |
+| `OfflineTrajectoryWrapper` | Accumulates transitions into a fixed-size circular replay buffer |
+
+**Wrapper stacks:**
+
+```
+Training: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
+Inference: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
+```
+
+---
+
+## Project Structure
+
+```
+craftax-ReMDM-planner/
+├── Craftax_Baselines/ # Git submodule — PPO agents and standard wrappers
+│ ├── wrappers.py # LogWrapper, BatchEnvWrapper, AutoResetEnvWrapper, etc.
+│ ├── ppo_rnn.py # PPO-RNN training script
+│ ├── ppo_rnd.py # PPO-RND training script
+│ ├── ppo.py # PPO model definitions
+│ └── models/
+│ ├── actor_critic.py # ActorCritic variants
+│ ├── rnd.py # RND network
+│ └── icm.py # ICM encoder, forward, and inverse networks
+├── configs/
+│ ├── defaults.yaml # Base hyperparameters (CLI-overridable)
+│ ├── classic_exp_a_beta_fix.yaml # Classic DAgger — beta decay fix only
+│ ├── classic_exp_b_beta_big_model.yaml # Classic DAgger — beta fix + big model
+│ ├── classic_exp_c_full_recipe.yaml # Classic DAgger — full recipe
+│ ├── craftax_exp_a_beta_fix.yaml # Full Craftax DAgger — beta fix
+│ ├── craftax_exp_b_beta_big_model.yaml # Full Craftax DAgger — beta + big model
+│ ├── craftax_exp_c_full_recipe.yaml # Full Craftax DAgger — full recipe
+│ ├── final_classic_ucl.yaml # Classic DAgger — UCL 3090 Ti, seed 42
+│ ├── final_classic_qmul.yaml # Classic DAgger — QMUL H200, seed 43
+│ ├── final_craftax_ucl.yaml # Full Craftax DAgger — UCL 4090, seed 42
+│ └── final_craftax_qmul.yaml # Full Craftax DAgger — QMUL H200, seed 43
+├── src/
+│ ├── diffusion/
+│ │ ├── forward.py # Forward masking process q(z_t | x_0)
+│ │ ├── loss.py # Continuous-time MDLM ELBO loss
+│ │ ├── sampling.py # Reverse diffusion with ReMDM remasking
+│ │ └── schedules.py # Linear and cosine noise schedules
+│ ├── models/
+│ │ └── denoiser.py # DenoisingTransformer (obs encoder + transformer)
+│ ├── envs/
+│ │ └── wrappers.py # Sequence, tokenization, planner, and trajectory wrappers
+│ └── planners/
+│ ├── collect.py # --mode collect: PPO rollouts -> .npz
+│ ├── common.py # Shared utilities
+│ ├── env.py # Environment construction
+│ ├── inference.py # --mode inference: MPC evaluation with inpainting
+│ ├── logging.py # Centralised W&B logging utilities
+│ ├── model.py # Diffusion model lifecycle
+│ ├── offline.py # --mode offline: make_train (live PPO rollouts)
+│ ├── online.py # --mode online: DAgger fine-tuning
+│ └── ppo.py # PPO agent adapter and checkpoint loading utilities
+├── experiments/
+│ └── rl_finetuning/ # RL fine-tuning ablation suite (see experiments/README.md)
+│ ├── run_ablations.py # CLI entry point
+│ ├── ablations/ # Loss, optimizer, registry, and training modules
+│ ├── diagnostics/ # Gradient, representation, and timestep diagnostics
+│ ├── analysis/ # Plots, tables, and report generation
+│ └── configs/ # ablations_default.yaml, ablations_fast.yaml,
+│ # ablations_final_{classic,craftax}_{ucl,qmul}.yaml
+├── main.py # CLI entry point
+├── pyproject.toml # uv project — direct deps + tool config
+└── uv.lock # Reproducible lockfile (commit this)
+```
+
+---
+
+## Implementation Notes
+
+**JAX functional purity**: training closures (`make_train`, `make_train_dagger`) are fully JIT-compatible. Environment construction and checkpoint I/O happen outside `jax.jit`.
+
+**Offline training**: `--mode offline` rolls out the PPO agent live at each update step via `make_train`. Use `--mode collect` to save a trajectory `.npz` for inspection or analysis; re-feeding it to `--mode offline` is not supported — pass `--ppo_checkpoint_path` instead.
+
+**Episode-boundary masking**: the offline sampler pre-computes a validity mask over all `(env, time)` positions. A window at `(e, t)` is valid only if `dones[e, t+1:t+H-1]` are all `False`.
+
+**Return weighting**: valid windows are weighted by their cumulative reward, normalised by the batch mean and clipped to `[0.1, RETURN_WEIGHT_CAP]`. Weights are passed as per-sample multipliers into the MDLM loss before reduction, so they correctly scale each sample's gradient contribution.
+
+**LR schedule**: cosine decay from `lr` to `lr * 0.1` over all gradient steps. Set `lr_warmup_frames > 0` (env-frame-invariant, PRIMARY) or `lr_warmup_steps > 0` (LEGACY) to prepend a linear warm-up phase.
+
+**Env-frame-invariant hyperparameters**: the PRIMARY keys `offline_total_timesteps`, `online_total_timesteps`, `lr_warmup_frames`, `val_interval_frames`, `dagger_beta_final`, and `dagger_buffer_cycles` are denominated in env frames (or update cycles). At config load time, `resolve_scaled_hyperparams()` in `src/planners/common.py` converts them to the equivalent update-step-denominated quantities (`num_updates`, `lr_warmup_steps`, `val_interval`, `dagger_beta_decay`, `dagger_buffer_max`) using the current `num_envs * num_steps` frames-per-update. This lets the same config run on different hardware tiers without re-tuning.
+
+**Loss weight clipping**: the MDLM SUBS weight `-alpha'(t) / (1 - alpha_t)` is clipped to 1000 to prevent numerical instability when `alpha_t ≈ 1`.
+
+**Validation rollouts**: during offline training, a held-out rollout runs every `val_interval` steps. It uses the same sampling parameters as inference (`remask_strategy`, `eta`, `use_loop`, `t_on`, `t_off`, `temperature`, `top_p`) with `val_diffusion_steps` denoising steps and `val_replan_every` env steps per plan, for a total of `val_steps` environment steps.
+
+**W&B logging**: all metric aggregation is centralised in `src/planners/logging.py`. Metric namespaces: `diffusion/` (loss, accuracy), `train/` (data quality, throughput), `env/` (episode returns, achievements), `val/` (validation rollouts, emitted every `val_interval` steps), `dagger/` (online DAgger training: beta, buffer fill, reward mean, valid fraction). `train/sps` (environment frames/sec) is only logged in modes that perform live environment interaction.
+
+**DAgger dataset aggregation**: online training (`--mode online`) implements DAgger (Ross et al., 2011). A circular replay buffer accumulates `(obs, expert_plan)` pairs across all iterations. Each update samples uniformly from the full buffer, not just the latest batch. Training samples that cross episode boundaries (any `done` within the plan-horizon window) are marked invalid. The expert (PPO agent) receives correct `done` flags so its RNN hidden state resets on episode boundaries. Windows are extracted with a sliding stride (one per env-time position) rather than stepping the buffer in plan-horizon chunks, so every visited state contributes a label.
+
+**Best-checkpoint tracking**: during online training, the parameters from the validation iteration with the highest validation return are preserved alongside the current live parameters. The final checkpoint and the best-validation checkpoint are both uploaded as separate W&B artifacts (`{env_name}-policy` and `{env_name}-policy-best`).
+
+**Denoising step indexing**: the reverse scan runs from `step_idx = 0` to `T-1`, mapping to diffusion time `t = (T - step_idx) / T` (high noise to low noise).
+
+**Submodule PPO agents**: PPO training lives entirely in `Craftax_Baselines/`. Planner scripts only consume pre-trained checkpoints via `--ppo_checkpoint_path`.
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..af9deeec510ac5142d3a4f9cdbae10bb10eb761a
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA
@@ -0,0 +1 @@
+{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775663434533263974, "commit_timestamp_nsecs": 1775663435644779625, "custom_metadata": {}}
\ No newline at end of file
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..8d64d23ddc877e9c3a263d5a2d60556a82bc8ff2
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA
@@ -0,0 +1 @@
+{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('params', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('params', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
\ No newline at end of file
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding
new file mode 100644
index 0000000000000000000000000000000000000000..c3e2898e48ce0d01971669ae6a0dc6bd5228e39f
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding
@@ -0,0 +1 @@
+{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
\ No newline at end of file
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0
new file mode 100644
index 0000000000000000000000000000000000000000..af0fb84132e9bc0aa2fed6958e07a83773a12ae5
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0
@@ -0,0 +1 @@
+{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]}
\ No newline at end of file
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af
new file mode 100644
index 0000000000000000000000000000000000000000..40063e0bb0125db03d6fa58bf65a3cb8328f1805
Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af differ
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..b12a2bab77e77df915993ace51a110585d5eeb2c
Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt differ
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92
new file mode 100644
index 0000000000000000000000000000000000000000..a930731853d772682631bb44ded1cbd4e369f83f
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbff61a18e9475d72fae302d4748615daf5fc6b87cc0e0a338c96b8a781d6c0f
+size 101199872
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a
new file mode 100644
index 0000000000000000000000000000000000000000..04a1c2a5e61500bb83ea8fee15a98bdda281a09b
Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a differ
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84
new file mode 100644
index 0000000000000000000000000000000000000000..32aeb3b90d60b0a7055cd0d594f6225851c8357d
Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 differ
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf
new file mode 100644
index 0000000000000000000000000000000000000000..4ee8641d9ce591743e3d806da9a6038ed872704c
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66e7df58a5ad39030e5631943ffa5d45164b91f283a2b7b34d4265c6bbf08be4
+size 448037
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..de1c9f551f4ad5ea78d886518e8ad52d1e0d0997
Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt differ
diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..4938279b4e3dbb56ee812a59ca9f0c415381ee7b
--- /dev/null
+++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json
@@ -0,0 +1,68 @@
+{
+ "mode": "offline",
+ "update_step": 1525,
+ "total_gradient_steps_completed": 97600,
+ "wandb_run_id": "6opvce2t",
+ "config_snapshot": {
+ "ENV_NAME": "Craftax-Classic-Symbolic-v1",
+ "USE_OPTIMISTIC_RESETS": false,
+ "OPTIMISTIC_RESET_RATIO": 16,
+ "D_MODEL": 384,
+ "N_HEADS": 8,
+ "N_LAYERS": 6,
+ "D_FF": 768,
+ "OBS_ENCODER_LAYERS": 2,
+ "OBS_ENCODER_WIDTH": 768,
+ "DROPOUT_RATE": 0.1,
+ "PLAN_HORIZON": 32,
+ "DIFFUSION_SCHEDULE": "cosine",
+ "TRAIN_SIGMA": 0.0,
+ "LABEL_SMOOTHING": 0.0,
+ "DIFFUSION_STEPS": 15,
+ "DIFFUSION_STEPS_EVAL": 10,
+ "REMASK_STRATEGY": "rescale",
+ "ETA": 0.5,
+ "USE_LOOP": true,
+ "T_ON": 0.7,
+ "T_OFF": 0.3,
+ "TEMPERATURE": 0.5,
+ "TOP_P": 0.95,
+ "LR": 0.0003,
+ "MAX_GRAD_NORM": 1.0,
+ "LR_WARMUP_FRAMES": "1.048576e8",
+ "NUM_ENVS": 512,
+ "NUM_STEPS": 128,
+ "NUM_MINIBATCHES": 8,
+ "UPDATE_EPOCHS": 8,
+ "NUM_REPEATS": 1,
+ "OFFLINE_TOTAL_TIMESTEPS": 99942400,
+ "COLLECT_TEMPERATURE": 1.0,
+ "RETURN_WEIGHT_CAP": 5.0,
+ "ONLINE_TOTAL_TIMESTEPS": 100000000.0,
+ "DAGGER_BETA_INIT": 1.0,
+ "DAGGER_BETA_FINAL": 0.344,
+ "DAGGER_BUFFER_CYCLES": 1.90735,
+ "VAL_INTERVAL_FRAMES": 1000000.0,
+ "VAL_DIFFUSION_STEPS": 50,
+ "VAL_REPLAN_EVERY": 4,
+ "VAL_STEPS": 256,
+ "COLLECT_NUM_STEPS": 10000000,
+ "COLLECT_NUM_ENVS": 128,
+ "PPO_MODEL_TYPE": "ppo_rnn",
+ "LAYER_SIZE": 512,
+ "EVAL_STEPS": 10000,
+ "EVAL_NUM_ENVS": 32,
+ "SAVE_POLICY": true,
+ "SEED": 42,
+ "USE_WANDB": true,
+ "WANDB_PROJECT": "remdm-craftax",
+ "WANDB_ENTITY": "mathis-weil-university-college-london-ucl-",
+ "MODE": "offline",
+ "JIT": true,
+ "PPO_CHECKPOINT_PATH": "checkpoints/ppo_agents/policies/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M",
+ "NUM_UPDATES": 1525,
+ "LR_WARMUP_STEPS": 1600,
+ "VAL_INTERVAL": 15,
+ "MINIBATCH_SIZE": 6208
+ }
+}
\ No newline at end of file
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..0986df91a3ab4b9643bdb9f4279a648481c330bf
--- /dev/null
+++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA
@@ -0,0 +1 @@
+{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775623858059636986, "commit_timestamp_nsecs": 1775623858516125466, "custom_metadata": {}}
\ No newline at end of file
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..de3d5139d201d649880d3082c460bbb0f7b57ca0
--- /dev/null
+++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA
@@ -0,0 +1 @@
+{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "scalar", "skip_deserialize": false}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('params', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('params', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '1')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
\ No newline at end of file
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding
new file mode 100644
index 0000000000000000000000000000000000000000..0018d0fe5602c2ae1f48d1ba9055fa7e1936ae8c
--- /dev/null
+++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding
@@ -0,0 +1 @@
+{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
\ No newline at end of file
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0
new file mode 100644
index 0000000000000000000000000000000000000000..6dc09f7d71ab2572709c42a52d44ad5b59e2d1ea
--- /dev/null
+++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0
@@ -0,0 +1 @@
+{"array_metadatas": [{"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}]}
\ No newline at end of file
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86
new file mode 100644
index 0000000000000000000000000000000000000000..fae357a26ebb3db90b38a3447d1d9da8e86e4e15
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 differ
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..22b3d5b45e54da7718c063cb92310f13a16d915e
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt differ
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043
new file mode 100644
index 0000000000000000000000000000000000000000..804a604796c2a9823bc945bd872bcbe7cbab66a3
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 differ
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30
new file mode 100644
index 0000000000000000000000000000000000000000..74c3bcb593785b1fb3634f7100051a528ef9a785
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 differ
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc
new file mode 100644
index 0000000000000000000000000000000000000000..c67c899fef975fe77b32faa9470bd7f540359e9e
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc differ
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620
new file mode 100644
index 0000000000000000000000000000000000000000..bffb28ab5a843d5164c98e7877cadcfab401f9a4
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 differ
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a
new file mode 100644
index 0000000000000000000000000000000000000000..bf7a6fd370cd12e6dbe30df5d9edbff2e79691eb
--- /dev/null
+++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c27dc63cbdd625c2b62fac311fd37e14406b411ac848847ca4bd4e99f333419
+size 34631680
diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..fd80a355e5475f60ba25b97606b58246ac1d3956
Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt differ
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..7e036dbbbbb68b3f50bcad514a9ffea587b5f20b
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA
@@ -0,0 +1 @@
+{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1773173340517772966, "commit_timestamp_nsecs": 1773173340998852009, "custom_metadata": {}}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..4e46b06ae0446645451fbf801d54c3264da0ff25
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA
@@ -0,0 +1 @@
+{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding
new file mode 100644
index 0000000000000000000000000000000000000000..429946c94c750adf4803fc2959c02aa1f8cf68a1
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding
@@ -0,0 +1 @@
+{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmh6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0
new file mode 100644
index 0000000000000000000000000000000000000000..8791b9fe04d1a3996ac3d3d5d86a46145f675956
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0
@@ -0,0 +1 @@
+{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 512], "chunk_shape": [1345, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [512, 17], "chunk_shape": [512, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 512], "chunk_shape": [1345, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [512, 17], "chunk_shape": [512, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 512], "chunk_shape": [1345, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [512, 17], "chunk_shape": [512, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/3d817e301205eacc425259e9b57de121 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/3d817e301205eacc425259e9b57de121
new file mode 100644
index 0000000000000000000000000000000000000000..82f4fad7fcb7fb60d275db8afe17c2b599c4127a
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/3d817e301205eacc425259e9b57de121 differ
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..2100c67f1592ac1d222a76a33db3d076ffb9485d
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt differ
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91
new file mode 100644
index 0000000000000000000000000000000000000000..923b34d85e6d1c9fa2ee825866041e90b6720a54
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25bead7246dcfff317ab909012529b50f77ce42d2b976cd3d66baf4e3013ef6d
+size 31166464
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/553952417e25cc3b880b7c458a1b4fa6 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/553952417e25cc3b880b7c458a1b4fa6
new file mode 100644
index 0000000000000000000000000000000000000000..8d0bbc1b3c5cb6e70fe084effae122d5917185a5
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/553952417e25cc3b880b7c458a1b4fa6 differ
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/592d44df588422e75968df016db43e91 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/592d44df588422e75968df016db43e91
new file mode 100644
index 0000000000000000000000000000000000000000..7d05015519f121430dced55f611012e5d6052ad7
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/592d44df588422e75968df016db43e91 differ
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9
new file mode 100644
index 0000000000000000000000000000000000000000..5320f50b3e7cb2eaeb7867db10bd5e1f09ab96c1
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2e86dd7454e40b6286af72bd0c2c1a1df1914c39f1fef20c0a37f32accff16b
+size 5246976
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/ca2a63995e2d2ca4dffb0ca4171ab0ee b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/ca2a63995e2d2ca4dffb0ca4171ab0ee
new file mode 100644
index 0000000000000000000000000000000000000000..32aeb3b90d60b0a7055cd0d594f6225851c8357d
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/ca2a63995e2d2ca4dffb0ca4171ab0ee differ
diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..d342ef752f388a0dc1c6ddfe900594f94a6f61b0
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt differ
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..1fc6ed7f6c28d1a129092315199a11b4692de31d
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA
@@ -0,0 +1 @@
+{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1774217044031399688, "commit_timestamp_nsecs": 1774217044369893130, "custom_metadata": {}}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..fdb5a1b7aa2860843779d07bdcb8229313c52d71
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA
@@ -0,0 +1 @@
+{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8268, 512]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [43]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 43]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8268, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [43]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 43]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8268, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [43]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 43]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding
new file mode 100644
index 0000000000000000000000000000000000000000..429946c94c750adf4803fc2959c02aa1f8cf68a1
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding
@@ -0,0 +1 @@
+{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmh6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0
new file mode 100644
index 0000000000000000000000000000000000000000..d9161e15d6b3e9297e6700b5e515328ea8a34296
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0
@@ -0,0 +1 @@
+{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [8268, 512], "chunk_shape": [8268, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [43], "chunk_shape": [43], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [512, 43], "chunk_shape": [512, 43], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [8268, 512], "chunk_shape": [8268, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [43], "chunk_shape": [43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [512, 43], "chunk_shape": [512, 43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [8268, 512], "chunk_shape": [8268, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [43], "chunk_shape": [43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [512, 43], "chunk_shape": [512, 43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]}
\ No newline at end of file
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/4c4db8700ec36cc1416c034bfbe0f71a b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/4c4db8700ec36cc1416c034bfbe0f71a
new file mode 100644
index 0000000000000000000000000000000000000000..fb6ebf9f33d1d8aa4170f17f6aee74ee7814bd7b
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/4c4db8700ec36cc1416c034bfbe0f71a differ
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..02fec15038b36ed47875f7f6884eaee855ab03cd
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt differ
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/0f5cd6bf71d63c0de8e8764ba0de8349 b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/0f5cd6bf71d63c0de8e8764ba0de8349
new file mode 100644
index 0000000000000000000000000000000000000000..d34c873b8c69918f68938c5d2826dacc4a096b65
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/0f5cd6bf71d63c0de8e8764ba0de8349 differ
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3258293b820e7cc97e3af7b3d69255b6 b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3258293b820e7cc97e3af7b3d69255b6
new file mode 100644
index 0000000000000000000000000000000000000000..a00b64eed77a32cfa481a8bed9767b8ef3754688
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3258293b820e7cc97e3af7b3d69255b6 differ
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3566d5252a486c26c607b2637742f32f b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3566d5252a486c26c607b2637742f32f
new file mode 100644
index 0000000000000000000000000000000000000000..2df53ebd75e606fb7c7b6d9d9146fa05f586ee6b
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3566d5252a486c26c607b2637742f32f differ
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb
new file mode 100644
index 0000000000000000000000000000000000000000..083f9210c8836dbcfd9fe44f005a306d268a5b6c
--- /dev/null
+++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fba36a05ef8ebbfda38b7f622137b5e9c9bd56cf6d6a534c4acb34e1897082f1
+size 51884032
diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt
new file mode 100644
index 0000000000000000000000000000000000000000..ca6db89a1a9e9a6e8b93af317e0a23ec59be1d36
Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt differ
diff --git a/configs/classic_exp_a_beta_fix.yaml b/configs/classic_exp_a_beta_fix.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5fc8cc11e14e988cb387b642798cb2fc232e2cbf
--- /dev/null
+++ b/configs/classic_exp_a_beta_fix.yaml
@@ -0,0 +1,90 @@
+# =============================================================================
+# Experiment A — Beta decay fix only (isolate data quality)
+# Target hardware: NVIDIA RTX 3090 Ti (24 GB)
+# Slows dagger_beta_decay so the expert stays present roughly 2x longer.
+# Everything else identical to the baseline.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 256
+n_heads: 4
+n_layers: 4
+d_ff: 512
+obs_encoder_layers: 2
+obs_encoder_width: 512
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_steps: 200
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 4096
+num_steps: 128
+num_minibatches: 16
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_num_updates: 1500 # val plateaus by step 800
+dagger_beta_init: 1.0
+dagger_beta_decay: 0.9993 # beta > 0.5 at step 990, > 0.1 at step 3290
+dagger_buffer_max: 1000000
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval: 30
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/classic_exp_b_beta_big_model.yaml b/configs/classic_exp_b_beta_big_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cb6c0ea47cc28faf6478f4e869376855f3e77e64
--- /dev/null
+++ b/configs/classic_exp_b_beta_big_model.yaml
@@ -0,0 +1,90 @@
+# =============================================================================
+# Experiment B — Beta fix + bigger model (isolate model capacity)
+# Target hardware: NVIDIA RTX 3090 Ti (24 GB)
+# Cumulative with Experiment A: same beta decay fix, plus a 3.5x larger
+# transformer (~18 GB peak VRAM).
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_steps: 200
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 4096
+num_steps: 128
+num_minibatches: 8 # minibatch 2048, better gradients than Exp A
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_num_updates: 1500
+dagger_beta_init: 1.0
+dagger_beta_decay: 0.9993
+dagger_buffer_max: 1000000
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval: 30
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/classic_exp_c_full_recipe.yaml b/configs/classic_exp_c_full_recipe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6565bcf16ceccf225abc786ab6a0c2cb3adc7a35
--- /dev/null
+++ b/configs/classic_exp_c_full_recipe.yaml
@@ -0,0 +1,90 @@
+# =============================================================================
+# Experiment C — Full recipe (beta + big model + training dynamics)
+# Target hardware: NVIDIA RTX 4090 (24 GB)
+# Cumulative with Experiments A + B: also bumps diffusion_steps, lowers
+# temperature, raises lr, and tightens validation to match training settings.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25 # more denoising for cleaner plans
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3 # sharper sampling than baseline
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4 # scale with 2x effective batch
+max_grad_norm: 1.0
+lr_warmup_steps: 300 # stabilize higher LR + bigger model
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 448
+num_steps: 128
+num_minibatches: 8 # minibatch 2048
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 1.0e+8
+dagger_beta_init: 1.0
+dagger_beta_decay: 0.9993
+dagger_buffer_max: 1000000
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval: 30
+val_diffusion_steps: 25 # match diffusion_steps
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/classic_exp_d_100K_model.yaml b/configs/classic_exp_d_100K_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ee0c12c9b2814352ece10c5c3005262a1939837
--- /dev/null
+++ b/configs/classic_exp_d_100K_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.1M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 32
+n_heads: 2
+n_layers: 1
+d_ff: 64
+obs_encoder_layers: 1
+obs_encoder_width: 64
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 3072
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/classic_exp_d_250K_model.yaml b/configs/classic_exp_d_250K_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3d96a18a41bc13bb472747769c9d2d109628ef0a
--- /dev/null
+++ b/configs/classic_exp_d_250K_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.25M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 64
+n_heads: 2
+n_layers: 2
+d_ff: 128
+obs_encoder_layers: 1
+obs_encoder_width: 128
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 6144
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/classic_exp_d_3M_model.yaml b/configs/classic_exp_d_3M_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..feb140804795139458315344007014cc5566db65
--- /dev/null
+++ b/configs/classic_exp_d_3M_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Craftax Classic DAgger — UCL RTX 3090 Ti 24 GB - EXP D - 3.2M model
+# Big transformer + slow beta decay. Produces the Classic DAgger checkpoint
+# used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 256
+n_heads: 4
+n_layers: 4
+d_ff: 512
+obs_encoder_layers: 2
+obs_encoder_width: 512
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 1.048576e8 # = 200 update steps at this hardware (200 * 4096 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 1120
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.344 # 0.9993^1525 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 1.90735 # ~1M samples on UCL = ~1.91 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/classic_exp_d_850K_model.yaml b/configs/classic_exp_d_850K_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e1c2c49c3faccefb98985734a7c0a2e2a1459c9
--- /dev/null
+++ b/configs/classic_exp_d_850K_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.85M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 128
+n_heads: 4
+n_layers: 3
+d_ff: 256
+obs_encoder_layers: 2
+obs_encoder_width: 256
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 5120
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_a_beta_fix.yaml b/configs/craftax_exp_a_beta_fix.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..20b2048e3c63d54e839dae1a564d3577da058fa4
--- /dev/null
+++ b/configs/craftax_exp_a_beta_fix.yaml
@@ -0,0 +1,90 @@
+# =============================================================================
+# Experiment A-Full — Beta fix only on Full Craftax
+# Target hardware: NVIDIA RTX 4090 (24 GB)
+# Same as exp_a_beta_fix.yaml but on Craftax-Symbolic-v1 (obs_dim 8268, 43
+# actions). num_envs and dagger_buffer_max are dropped to fit the larger obs.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 256
+n_heads: 4
+n_layers: 4
+d_ff: 512
+obs_encoder_layers: 2
+obs_encoder_width: 512
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_steps: 200
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 2048 # halved from 4096: larger obs per env
+num_steps: 128
+num_minibatches: 8 # minibatch = 2048 * 128 / 8 = 32768
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_num_updates: 2000 # harder task, more room to learn
+dagger_beta_init: 1.0
+dagger_beta_decay: 0.9993
+dagger_buffer_max: 200000 # 8268 obs -> ~6.6 GB at 200k
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval: 30
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_b_beta_big_model.yaml b/configs/craftax_exp_b_beta_big_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ccf979e9a4d98859034de4cd38ad1d992699cbd7
--- /dev/null
+++ b/configs/craftax_exp_b_beta_big_model.yaml
@@ -0,0 +1,90 @@
+# =============================================================================
+# Experiment C-Full — Full recipe on Full Craftax
+# Target hardware: NVIDIA RTX 3090 Ti / 4090 (24 GB)
+# Same as exp_c_full_recipe.yaml but on Craftax-Symbolic-v1 (obs_dim 8268, 43
+# actions). num_envs and dagger_buffer_max are dropped to fit the larger obs.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_steps: 300
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 448 # halved from 4096: larger obs per env
+num_steps: 128
+num_minibatches: 8 # minibatch = 2048 * 128 / 8 = 32768
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 1.0e+8 # harder task, more room to learn
+dagger_beta_init: 1.0
+dagger_beta_decay: 0.9993
+dagger_buffer_max: 200000 # 8268 obs -> ~6.6 GB at 200k
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25 # match diffusion_steps
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_c_full_recipe.yaml b/configs/craftax_exp_c_full_recipe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26b97abf03f74f25e025a454958af4ce74b924d5
--- /dev/null
+++ b/configs/craftax_exp_c_full_recipe.yaml
@@ -0,0 +1,90 @@
+# =============================================================================
+# Experiment B-Full — Beta fix + big model on Full Craftax
+# Target hardware: NVIDIA RTX 3090 Ti (24 GB)
+# Same as exp_b_beta_big_model.yaml but on Craftax-Symbolic-v1 (obs_dim 8268,
+# 43 actions). num_envs and dagger_buffer_max are dropped to fit the larger obs.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_steps: 200
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 2048 # halved from 4096: larger obs per env
+num_steps: 128
+num_minibatches: 8 # minibatch = 2048 * 128 / 8 = 32768
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_num_updates: 2000 # harder task, more room to learn
+dagger_beta_init: 1.0
+dagger_beta_decay: 0.9993
+dagger_buffer_max: 200000 # 8268 obs -> ~6.6 GB at 200k
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval: 30
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_d_1M_model.yaml b/configs/craftax_exp_d_1M_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..034e858eca3d5552b098c39632aa1ef798439a83
--- /dev/null
+++ b/configs/craftax_exp_d_1M_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 1.1M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 64
+n_heads: 2
+n_layers: 2
+d_ff: 128
+obs_encoder_layers: 1
+obs_encoder_width: 128
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 1536
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_d_3M_model.yaml b/configs/craftax_exp_d_3M_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1192b9ac0578f64e101479b45926b5a329c73503
--- /dev/null
+++ b/configs/craftax_exp_d_3M_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 2.6M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 128
+n_heads: 4
+n_layers: 3
+d_ff: 256
+obs_encoder_layers: 2
+obs_encoder_width: 256
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 1240
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_d_500K_model.yaml b/configs/craftax_exp_d_500K_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae5cf7146cadcd6d61d8547ec81591f561a415d2
--- /dev/null
+++ b/configs/craftax_exp_d_500K_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.5M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 32
+n_heads: 2
+n_layers: 1
+d_ff: 64
+obs_encoder_layers: 1
+obs_encoder_width: 64
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 2976
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/craftax_exp_d_7M_model.yaml b/configs/craftax_exp_d_7M_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..922ac05b772428036378e97996eddb1d1103b50a
--- /dev/null
+++ b/configs/craftax_exp_d_7M_model.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 6.8M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 256
+n_heads: 4
+n_layers: 4
+d_ff: 512
+obs_encoder_layers: 2
+obs_encoder_width: 512
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 832
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/defaults.yaml b/configs/defaults.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c08682df1800528c53199c9e59c73b22a875b24a
--- /dev/null
+++ b/configs/defaults.yaml
@@ -0,0 +1,120 @@
+# =============================================================================
+# ReMDM default configuration
+# Override any value via CLI: python main.py --mode offline --lr 1e-4
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 256
+n_heads: 4
+n_layers: 4
+d_ff: 512
+obs_encoder_layers: 2
+obs_encoder_width: 512
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine" # cosine | linear
+train_sigma: 0.0 # remasking correction during training (0 = standard MDLM)
+label_smoothing: 0.0 # cross-entropy label smoothing (0 = exact ELBO)
+
+# ── Reverse sampling (training, validation, inference) ──────────────────────
+diffusion_steps: 15 # denoising steps T during training
+diffusion_steps_eval: 10 # denoising steps T at inference (--mode inference)
+remask_strategy: "rescale" # rescale | cap | conf
+eta: 0.5 # remasking strength
+use_loop: true # three-phase loop (Algorithm 3)
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+# lr_warmup_frames is the env-frame-denominated source of truth (invariant
+# under num_envs changes). When set, it overrides lr_warmup_steps via:
+# lr_warmup_steps = lr_warmup_frames // (num_envs * num_steps)
+lr_warmup_frames: null # PRIMARY: env-frame warmup budget
+lr_warmup_steps: 0 # LEGACY: linear warm-up before cosine LR decay (0 = disabled)
+
+# ── Rollout (shared by offline + online) ─────────────────────────────────────
+num_envs: 1024
+num_steps: 64
+num_minibatches: 8
+update_epochs: 4
+num_repeats: 1 # independent training seeds (vmapped)
+
+# ── Offline training (--mode offline) ────────────────────────────────────────
+# num_updates is derived as offline_total_timesteps // (num_envs * num_steps),
+# so the env-frame budget is hardware-portable across num_envs changes.
+offline_total_timesteps: 1.0e+8 # PRIMARY: env-frame budget
+offline_num_updates: null # LEGACY: used only when offline_total_timesteps is unset
+collect_temperature: 1.0 # softmax temperature on PPO logits during data collection
+return_weight_cap: 5.0 # clip ceiling for per-window return weights
+
+# ── Online DAgger training (--mode online) ───────────────────────────────────
+# num_updates is derived as online_total_timesteps // (num_envs * num_steps).
+online_total_timesteps: null # PRIMARY: env-frame budget
+online_num_updates: 1000 # LEGACY: used only when online_total_timesteps is unset
+dagger_beta_init: 1.0 # initial expert mixing probability beta_1
+# dagger_beta_final is the env-frame-invariant source of truth: the target
+# mixing ratio at the end of training. When set, it overrides
+# dagger_beta_decay via: decay = (beta_final / beta_init) ** (1 / num_updates)
+dagger_beta_final: null # PRIMARY: target final beta
+dagger_beta_decay: 0.95 # LEGACY: per-update decay beta_i = beta_init * decay^i
+# dagger_buffer_cycles measures buffer capacity in *update cycles of history*
+# (1 cycle = num_envs * num_steps frames), invariant under num_envs. When
+# set, it overrides dagger_buffer_max via: buffer_max = cycles * fpu
+dagger_buffer_cycles: null # PRIMARY: buffer capacity in update cycles
+dagger_buffer_max: 100000 # LEGACY: max samples in DAgger replay buffer
+# B1: passes per update over the aggregated buffer. Each pass redraws a
+# fresh sample of size samples_per_update from D, so total samples seen
+# per update ≈ n_passes * samples_per_update. null = 1 pass, which
+# matches offline BC's per-update gradient work exactly (fair compute
+# comparison: same num_updates * update_epochs * num_minibatches grad
+# steps). Raise to >1 to trade BC fairness for higher per-update D
+# coverage (e.g. ~⌊|D|/samples_per_update⌋ for full-buffer coverage).
+dagger_train_passes: null
+# B2: deterministic argmax expert (true) vs categorical sampling (false).
+# True keeps the expert mapping s -> a* fixed, removing label noise from D.
+dagger_expert_deterministic: true
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+# val_interval_frames is the env-frame-invariant source of truth. When set,
+# it overrides val_interval via: val_interval = val_interval_frames // fpu
+val_interval_frames: null # PRIMARY: env-frames between validations
+val_interval: 50 # LEGACY: update steps between validation rollouts
+val_diffusion_steps: 50 # denoising steps during validation
+val_replan_every: 4 # env steps executed per diffusion plan
+val_steps: 128 # total env steps per validation rollout
+
+# ── Data collection (--mode collect) ─────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn" # ppo | ppo_rnn | ppo_rnd
+layer_size: 512
+
+# ── Inference (--mode inference) ─────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null # accepts wandb: refs
+resume_wandb_run_id: null # auto-read from checkpoint metadata if null
+resume_step: null # auto-read from checkpoint metadata if null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: null # random if not set
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
+wandb_download_dir: null # null = wandb default ./artifacts/
diff --git a/configs/final_classic_qmul.yaml b/configs/final_classic_qmul.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d93464d885797ee0b1f1e2404a83d9328cb0df2
--- /dev/null
+++ b/configs/final_classic_qmul.yaml
@@ -0,0 +1,93 @@
+# =============================================================================
+# Final Craftax Classic DAgger — QMUL H200 8 GB partition (seed 43) - EXP B - 9M model
+# Env-frame-matched second seed of final_classic_ucl.yaml.
+# -----------------------------------------------------------------------------
+# Identical to final_classic_ucl.yaml except for num_envs (96 vs 4096) and
+# seed (43 vs 42). All fairness-critical hyperparameters are now denominated
+# in env frames or update cycles, so they are AUTOMATICALLY scaled to this
+# hardware tier by resolve_scaled_hyperparams() — no manual derivation.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 1.048576e8 # IDENTICAL to UCL; resolves to 8533 update steps on 96 envs
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 96 # VRAM-tuned to ~6.5–7.5 GB; minibatch = 96*128/8 = 1536
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 1.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.344
+dagger_buffer_cycles: 1.90735
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 43
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/final_classic_ucl.yaml b/configs/final_classic_ucl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9619cda71db09d0a11c5e1ada69b28c605afcf5f
--- /dev/null
+++ b/configs/final_classic_ucl.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Craftax Classic DAgger — UCL RTX 3090 Ti 24 GB - EXP B - 9M model
+# Big transformer + slow beta decay. Produces the Classic DAgger checkpoint
+# used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Classic-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 15
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.5
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 3.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 1.048576e8 # = 200 update steps at this hardware (200 * 4096 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 512
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 1.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 1.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.344 # 0.9993^1525 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 1.90735 # ~1M samples on UCL = ~1.91 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 50
+val_replan_every: 4
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/final_craftax_qmul.yaml b/configs/final_craftax_qmul.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b0ed8951ff8b4db08c41f7b5c78dcf8813cdc52
--- /dev/null
+++ b/configs/final_craftax_qmul.yaml
@@ -0,0 +1,94 @@
+# =============================================================================
+# Final Full Craftax DAgger — QMUL H200 8 GB partition - EXP C - 15M model
+# Env-frame-matched second seed of final_craftax_full_ucl.yaml.
+# -----------------------------------------------------------------------------
+# Identical to final_craftax_full_ucl.yaml except for num_envs (384 vs 2048)
+# and seed (43 vs 42). All fairness-critical hyperparameters are now
+# denominated in env frames or update cycles, so they are AUTOMATICALLY
+# scaled to this hardware tier by resolve_scaled_hyperparams() — no manual
+# derivation.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # IDENTICAL to UCL; resolves to 1600 update steps on 384 envs
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 64
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8 # IDENTICAL to UCL
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8 # IDENTICAL to UCL
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # IDENTICAL to UCL; resolves to 0.99990623 decay on 384 envs
+dagger_buffer_cycles: 0.76294 # IDENTICAL to UCL; resolves to ~37500 samples on 384 envs
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Resume ───────────────────────────────────────────────────────────────────
+resume_checkpoint_path: null
+resume_wandb_run_id: null
+resume_step: null
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 43
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/configs/final_craftax_ucl.yaml b/configs/final_craftax_ucl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5a26a872b45f1843b095e3d199ff550996f071e6
--- /dev/null
+++ b/configs/final_craftax_ucl.yaml
@@ -0,0 +1,84 @@
+# =============================================================================
+# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP C - 15M model
+# Big transformer + full sampling recipe + slow beta decay. Produces the Full
+# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite.
+# =============================================================================
+
+# ── Environment ──────────────────────────────────────────────────────────────
+env_name: "Craftax-Symbolic-v1"
+use_optimistic_resets: false
+optimistic_reset_ratio: 16
+
+# ── Transformer architecture ─────────────────────────────────────────────────
+d_model: 384
+n_heads: 8
+n_layers: 6
+d_ff: 768
+obs_encoder_layers: 2
+obs_encoder_width: 768
+dropout_rate: 0.1
+
+# ── Diffusion ────────────────────────────────────────────────────────────────
+plan_horizon: 32
+diffusion_schedule: "cosine"
+train_sigma: 0.0
+label_smoothing: 0.0
+
+# ── Reverse sampling ─────────────────────────────────────────────────────────
+diffusion_steps: 25
+diffusion_steps_eval: 10
+remask_strategy: "rescale"
+eta: 0.5
+use_loop: true
+t_on: 0.7
+t_off: 0.3
+temperature: 0.3
+top_p: 0.95
+
+# ── Optimisation ─────────────────────────────────────────────────────────────
+lr: 5.0e-4
+max_grad_norm: 1.0
+lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128)
+
+# ── Rollout ──────────────────────────────────────────────────────────────────
+num_envs: 448
+num_steps: 128
+num_minibatches: 8
+update_epochs: 8
+num_repeats: 1
+
+# ── Offline training ─────────────────────────────────────────────────────────
+offline_total_timesteps: 2.0e+8
+collect_temperature: 1.0
+return_weight_cap: 5.0
+
+# ── Online DAgger training ───────────────────────────────────────────────────
+online_total_timesteps: 2.0e+8
+dagger_beta_init: 1.0
+dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu
+dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history
+
+# ── Validation rollouts ──────────────────────────────────────────────────────
+val_interval_frames: 1.0e+6
+val_diffusion_steps: 25
+val_replan_every: 2
+val_steps: 256
+
+# ── Data collection ──────────────────────────────────────────────────────────
+collect_num_steps: 10000000
+collect_num_envs: 128
+ppo_model_type: "ppo_rnn"
+layer_size: 512
+
+# ── Inference ────────────────────────────────────────────────────────────────
+eval_steps: 10000
+eval_num_envs: 32
+
+# ── Checkpointing ────────────────────────────────────────────────────────────
+save_policy: true
+
+# ── Logging ──────────────────────────────────────────────────────────────────
+seed: 42
+use_wandb: true
+wandb_project: "remdm-craftax"
+wandb_entity: "mathis-weil-university-college-london-ucl-"
diff --git a/demo_craftax.ipynb b/demo_craftax.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e0f5fdb18e02260e540480dcec1bc896b47a5622
--- /dev/null
+++ b/demo_craftax.ipynb
@@ -0,0 +1,803 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "02c2db887e2d",
+ "metadata": {},
+ "source": [
+ "# COMP0258 \u2014 ReMDM Discrete Diffusion Planning in Craftax\n",
+ "\n",
+ "> **Self-contained Colab notebook.** Loads pre-trained checkpoints from a public\n",
+ "> HuggingFace repo and demonstrates live inference, agent visualisation and the\n",
+ "> ReMDM denoising loop. No training is performed inside the notebook.\n",
+ "\n",
+ "**How to use this notebook**\n",
+ "\n",
+ "1. Open in Google Colab (preferably with a GPU runtime).\n",
+ "2. Run all cells top-to-bottom.\n",
+ "3. To test on **unseen inputs**, edit the constants in the *Configuration*\n",
+ " cell (`SEED`, `ENV_NAME`, `EVAL_STEPS`, `EVAL_NUM_ENVS`, `DIFFUSION_STEPS_EVAL`)\n",
+ " and re-run from there. Every Craftax seed produces a procedurally generated\n",
+ " world the agent has never seen.\n",
+ "\n",
+ "**Submission compliance**\n",
+ "\n",
+ "- Cell 1 downloads everything from a single public HuggingFace repo\n",
+ " (`HF_REPO_ID`) \u2014 no authentication required.\n",
+ "- The pre-trained checkpoint is loaded; no training happens here.\n",
+ "- Live inference (Cells 5\u20137) demonstrates that reported numbers reproduce.\n",
+ "- Pre-computed ablation figures (Cells 9\u201310) demonstrate the research finding.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "03d921f981da",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CONFIGURATION \u2014 marker can edit any of these and re-run from here\n",
+ "# =============================================================================\n",
+ "\n",
+ "# Public HuggingFace repo containing src/, Craftax_Baselines/, configs/,\n",
+ "# checkpoints/ and pre-computed ablation outputs.\n",
+ "HF_REPO_ID = \"TODO_HF_REPO_ID\"\n",
+ "LOCAL_DIR = \"remdm-craftax\"\n",
+ "\n",
+ "# --- Reproducibility & evaluation knobs ---\n",
+ "SEED = 42\n",
+ "\n",
+ "# \"Craftax-Classic-Symbolic-v1\": 22 achievements, 17 actions (DAgger checkpoint)\n",
+ "# \"Craftax-Symbolic-v1\": 65 achievements, 43 actions (PPO expert only)\n",
+ "ENV_NAME = \"Craftax-Classic-Symbolic-v1\"\n",
+ "\n",
+ "EVAL_STEPS = 2000 # Per-env step budget for live diffusion eval\n",
+ "EVAL_NUM_ENVS = 16 # Parallel environments (lower if Colab OOMs)\n",
+ "DIFFUSION_STEPS_EVAL = 10 # Reverse denoising steps at inference time\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "81c4103fc614",
+ "metadata": {},
+ "source": [
+ "## 1. Project overview\n",
+ "\n",
+ "**Problem.** Plan action sequences in *Craftax*, a JAX-accelerated procedurally\n",
+ "generated open-world survival game (a Crafter reimplementation extended with\n",
+ "NetHack-like mechanics).\n",
+ "\n",
+ "**Approach.** A bidirectional **denoising transformer** generates a\n",
+ "`plan_horizon = 32` action plan by iteratively denoising masked discrete tokens\n",
+ "(MDLM / **ReMDM** \u2014 Wang et al.) conditioned on the current symbolic\n",
+ "observation. At inference, MPC executes one action per plan and re-plans every\n",
+ "step with **historical inpainting**: positions `0 .. hist_len - 1` are locked to\n",
+ "the actions actually taken.\n",
+ "\n",
+ "**Pipeline (offline, before this notebook).**\n",
+ "```\n",
+ "[1] PPO-RNN expert (Craftax_Baselines/ppo_rnn.py)\n",
+ "[2] Offline behaviour cloning (main.py --mode offline)\n",
+ "[3] Online DAgger fine-tuning (main.py --mode online)\n",
+ "[4] RL fine-tuning ablation suite \u2014 25 ablations\n",
+ "```\n",
+ "\n",
+ "**Research question.** Can RL fine-tuning improve a DAgger-pretrained masked\n",
+ "diffusion planner?\n",
+ "\n",
+ "**Headline finding.** DAgger faithfully imitates the PPO expert, but **no RL\n",
+ "ablation** (25 tested across regularisation, optimisation, data and capacity)\n",
+ "meaningfully improves on the DAgger checkpoint. This matches the parallel\n",
+ "finding from our MiniHack codebase, indicating the obstruction is\n",
+ "framework-independent.\n",
+ "\n",
+ "The notebook demonstrates the three things the marker needs to verify:\n",
+ "\n",
+ "| Cell | Demonstrates |\n",
+ "|---|---|\n",
+ "| 5 | Live inference numbers reproduce reported results |\n",
+ "| 6 | The agent meaningfully interacts with Craftax (rewards, achievements) |\n",
+ "| 7 | The ReMDM iterative-unmasking sampler in action |\n",
+ "| 8 | DAgger \u2248 PPO expert (live PPO eval) |\n",
+ "| 9\u201310 | Pre-computed ablation results \u2014 RL doesn't help |\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "449dfd350169",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# Install pinned dependencies, download the HF repo, set up sys.path\n",
+ "# =============================================================================\n",
+ "\n",
+ "import importlib\n",
+ "import os\n",
+ "import subprocess\n",
+ "import sys\n",
+ "\n",
+ "\n",
+ "def _pip(*pkgs: str) -> None:\n",
+ " subprocess.check_call(\n",
+ " [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs],\n",
+ " stdout=subprocess.DEVNULL,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "# Core dependencies (versions match pyproject.toml).\n",
+ "_pip(\n",
+ " \"huggingface_hub>=1.9.1\",\n",
+ " \"craftax>=1.5.0\",\n",
+ " \"flax>=0.12.6\",\n",
+ " \"optax>=0.2.8\",\n",
+ " \"orbax-checkpoint>=0.5\",\n",
+ " \"distrax>=0.1.7\",\n",
+ " \"chex>=0.1.91\",\n",
+ " \"polars>=1.39.3\",\n",
+ " \"orjson>=3.11.8\",\n",
+ " \"pyyaml>=6.0\",\n",
+ ")\n",
+ "\n",
+ "# JAX: keep Colab's preinstalled GPU build if present, otherwise install CPU.\n",
+ "try:\n",
+ " import jax # noqa: F401\n",
+ " if jax.default_backend() == \"gpu\":\n",
+ " pass\n",
+ " else:\n",
+ " raise RuntimeError(\"re-install needed\")\n",
+ "except Exception:\n",
+ " _pip(\"jax>=0.9.2\")\n",
+ " importlib.invalidate_caches()\n",
+ " import jax # noqa: F401\n",
+ "\n",
+ "import craftax # noqa: F401\n",
+ "\n",
+ "backend = jax.default_backend()\n",
+ "device = jax.devices()[0]\n",
+ "print(f\"JAX {jax.__version__} | backend={backend} | device={device}\")\n",
+ "print(f\"Craftax {craftax.__version__}\")\n",
+ "if backend != \"gpu\":\n",
+ " print(\n",
+ " \"WARNING: JAX is running on CPU. Inference will be ~10x slower than GPU.\"\n",
+ " \"\\n In Colab, switch to a GPU runtime via Runtime -> Change runtime type.\"\n",
+ " )\n",
+ "\n",
+ "# ----------------------------------------------------------------------------\n",
+ "# Pull the project tree (code, configs, checkpoints, pre-computed results)\n",
+ "# from a single public HuggingFace repo. No authentication required.\n",
+ "# ----------------------------------------------------------------------------\n",
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "snapshot_path = snapshot_download(repo_id=HF_REPO_ID, local_dir=LOCAL_DIR)\n",
+ "print(f\"Snapshot downloaded to: {snapshot_path}\")\n",
+ "\n",
+ "# Both the parent (for `import src.planners...`) and the Craftax_Baselines/\n",
+ "# subdir (for `from wrappers import ...` inside Craftax_Baselines/ppo_rnn.py)\n",
+ "# must be on sys.path. main.py does the same thing.\n",
+ "for p in (snapshot_path, os.path.join(snapshot_path, \"Craftax_Baselines\")):\n",
+ " if p not in sys.path:\n",
+ " sys.path.insert(0, p)\n",
+ "\n",
+ "os.chdir(snapshot_path)\n",
+ "print(f\"cwd: {os.getcwd()}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "e71fb8bee8cc",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# Load the pre-trained DAgger diffusion checkpoint\n",
+ "# =============================================================================\n",
+ "\n",
+ "import json\n",
+ "\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "from craftax.craftax_env import make_craftax_env_from_name\n",
+ "\n",
+ "from src.planners.model import build_model, load_checkpoint, make_apply_fns\n",
+ "\n",
+ "# Checkpoint inventory bundled with the HF repo.\n",
+ "DIFFUSION_OFFLINE_CKPT = (\n",
+ " \"checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M\"\n",
+ ")\n",
+ "DIFFUSION_ONLINE_CKPT = (\n",
+ " \"checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M\"\n",
+ ")\n",
+ "PPO_CKPT = {\n",
+ " \"Craftax-Classic-Symbolic-v1\": (\n",
+ " \"checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M\"\n",
+ " ),\n",
+ " \"Craftax-Symbolic-v1\": (\n",
+ " \"checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M\"\n",
+ " ),\n",
+ "}\n",
+ "\n",
+ "# We pin to the DAgger (online) checkpoint \u2014 that is the headline result.\n",
+ "# The architecture hyperparameters are stored alongside the offline checkpoint\n",
+ "# in `resume_metadata.json` (online and offline share the same architecture).\n",
+ "with open(os.path.join(DIFFUSION_OFFLINE_CKPT, \"resume_metadata.json\")) as f:\n",
+ " META = json.load(f)\n",
+ "ARCH_CFG = META[\"config_snapshot\"]\n",
+ "\n",
+ "# The diffusion checkpoints are trained on Classic Craftax. The denoiser's\n",
+ "# action head dimension is fixed at training time (num_actions = 17 for Classic),\n",
+ "# so we cannot transfer it to Full Craftax (43 actions). PPO covers both.\n",
+ "DIFFUSION_ENV_NAME = \"Craftax-Classic-Symbolic-v1\"\n",
+ "\n",
+ "env_init = make_craftax_env_from_name(DIFFUSION_ENV_NAME, auto_reset=True)\n",
+ "env_params_init = env_init.default_params\n",
+ "NUM_ACTIONS = int(env_init.action_space(env_params_init).n)\n",
+ "OBS_DIM = int(env_init.observation_space(env_params_init).shape[0])\n",
+ "PLAN_HORIZON = int(ARCH_CFG[\"PLAN_HORIZON\"])\n",
+ "\n",
+ "print(f\"DiffusionEnv : {DIFFUSION_ENV_NAME}\")\n",
+ "print(f\" obs_dim : {OBS_DIM}\")\n",
+ "print(f\" num_actions: {NUM_ACTIONS}\")\n",
+ "print(f\"Architecture : d_model={ARCH_CFG['D_MODEL']} n_layers={ARCH_CFG['N_LAYERS']} \"\n",
+ " f\"n_heads={ARCH_CFG['N_HEADS']} plan_horizon={PLAN_HORIZON}\")\n",
+ "\n",
+ "model = build_model(ARCH_CFG, NUM_ACTIONS)\n",
+ "apply_eval, _ = make_apply_fns(model)\n",
+ "diffusion_params = load_checkpoint(\n",
+ " model,\n",
+ " jax.random.PRNGKey(SEED),\n",
+ " OBS_DIM,\n",
+ " PLAN_HORIZON,\n",
+ " DIFFUSION_ONLINE_CKPT,\n",
+ ")\n",
+ "\n",
+ "n_params = sum(int(p.size) for p in jax.tree.leaves(diffusion_params))\n",
+ "print(f\"Loaded DAgger checkpoint: {n_params / 1e6:.2f}M parameters\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "e27e94345edf",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CELL 5 (PRIORITY) \u2014 live inference on `EVAL_NUM_ENVS` parallel envs\n",
+ "# =============================================================================\n",
+ "#\n",
+ "# This calls the *exact* production inference code path used by\n",
+ "# `python main.py --mode inference`. The marker can change ENV_NAME / SEED /\n",
+ "# EVAL_STEPS in Cell 1 to test on entirely fresh procedurally generated worlds.\n",
+ "#\n",
+ "# Reported numbers (DAgger online checkpoint, EVAL_STEPS=10 000, 32 envs,\n",
+ "# Craftax-Classic-Symbolic-v1):\n",
+ "# mean episode return \u2248 10.4 (compare with PPO expert \u2248 10.4)\n",
+ "# score ceiling: 22 (one point per achievement)\n",
+ "# -----------------------------------------------------------------------------\n",
+ "\n",
+ "from src.planners.inference import run_inference\n",
+ "\n",
+ "if \"Classic\" not in ENV_NAME:\n",
+ " print(\n",
+ " f\"ENV_NAME={ENV_NAME!r}: no diffusion checkpoint exists for Full Craftax\\n\"\n",
+ " f\"(action vocabularies differ: 17 vs 43). Skipping diffusion eval \u2014 see\\n\"\n",
+ " f\"Cell 8 for the matching PPO expert evaluation on this environment.\"\n",
+ " )\n",
+ "else:\n",
+ " inference_cfg = {\n",
+ " # Architecture (read from checkpoint metadata)\n",
+ " **{k: v for k, v in ARCH_CFG.items() if k.isupper()},\n",
+ " # Marker-controlled\n",
+ " \"ENV_NAME\": ENV_NAME,\n",
+ " \"SEED\": SEED,\n",
+ " \"EVAL_STEPS\": EVAL_STEPS,\n",
+ " \"EVAL_NUM_ENVS\": EVAL_NUM_ENVS,\n",
+ " \"DIFFUSION_STEPS_EVAL\": DIFFUSION_STEPS_EVAL,\n",
+ " # Always disable W&B inside the notebook\n",
+ " \"USE_WANDB\": False,\n",
+ " \"CHECKPOINT_PATH\": DIFFUSION_ONLINE_CKPT,\n",
+ " }\n",
+ " run_inference(inference_cfg)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "fa0d0926061a",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CELL 6 (PRIORITY) \u2014 visualise the diffusion planner acting in Craftax\n",
+ "# =============================================================================\n",
+ "#\n",
+ "# Re-runs the MPC + historical-inpainting loop on a small batch of envs and\n",
+ "# captures per-step rewards / achievements / actions for plotting.\n",
+ "# -----------------------------------------------------------------------------\n",
+ "\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from src.diffusion.sampling import sample_plan_inpainting\n",
+ "\n",
+ "VIZ_NUM_ENVS = 4\n",
+ "VIZ_STEPS = 600\n",
+ "\n",
+ "if \"Classic\" not in ENV_NAME:\n",
+ " print(f\"Skipping behaviour viz: no diffusion checkpoint for {ENV_NAME!r}.\")\n",
+ "else:\n",
+ " viz_env = make_craftax_env_from_name(DIFFUSION_ENV_NAME, auto_reset=True)\n",
+ " viz_params = viz_env.default_params\n",
+ "\n",
+ " rng = jax.random.PRNGKey(SEED + 1)\n",
+ " rng, env_rng = jax.random.split(rng)\n",
+ " obs0, state0 = jax.vmap(viz_env.reset, in_axes=(0, None))(\n",
+ " jax.random.split(env_rng, VIZ_NUM_ENVS), viz_params,\n",
+ " )\n",
+ " history0 = jnp.full((VIZ_NUM_ENVS, PLAN_HORIZON), NUM_ACTIONS, dtype=jnp.int32)\n",
+ " hist_len0 = jnp.zeros(VIZ_NUM_ENVS, dtype=jnp.int32)\n",
+ " env_indices = jnp.arange(VIZ_NUM_ENVS)\n",
+ "\n",
+ " @jax.jit\n",
+ " def viz_step(carry, _):\n",
+ " obs, state, rng, history, hist_len = carry\n",
+ " rng, plan_rng, env_rng = jax.random.split(rng, 3)\n",
+ "\n",
+ " # Reset history when plan window is exhausted (matches inference.py).\n",
+ " seq_full = hist_len >= PLAN_HORIZON\n",
+ " hist_len = jnp.where(seq_full, 0, hist_len)\n",
+ " history = jnp.where(seq_full[:, None], NUM_ACTIONS, history)\n",
+ "\n",
+ " plan = sample_plan_inpainting(\n",
+ " apply_eval, diffusion_params, plan_rng, obs,\n",
+ " history, hist_len, NUM_ACTIONS, PLAN_HORIZON,\n",
+ " DIFFUSION_STEPS_EVAL,\n",
+ " ARCH_CFG[\"TEMPERATURE\"], ARCH_CFG[\"TOP_P\"],\n",
+ " )\n",
+ " action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)\n",
+ " history = history.at[env_indices, hist_len].set(action)\n",
+ " hist_len = hist_len + 1\n",
+ "\n",
+ " obs_next, state_next, reward, done, _ = jax.vmap(\n",
+ " viz_env.step, in_axes=(0, 0, 0, None),\n",
+ " )(jax.random.split(env_rng, VIZ_NUM_ENVS), state, action, viz_params)\n",
+ "\n",
+ " hist_len = jnp.where(done, 0, hist_len)\n",
+ " history = jnp.where(done[:, None], NUM_ACTIONS, history)\n",
+ " return (\n",
+ " (obs_next, state_next, rng, history, hist_len),\n",
+ " (action, reward, done, state_next.achievements),\n",
+ " )\n",
+ "\n",
+ " print(f\"Rolling out {VIZ_NUM_ENVS} agents for {VIZ_STEPS} steps...\")\n",
+ " _, (acts, rews, dones, achs) = jax.lax.scan(\n",
+ " viz_step, (obs0, state0, rng, history0, hist_len0), jnp.arange(VIZ_STEPS),\n",
+ " )\n",
+ " acts_np = np.array(acts) # [T, E]\n",
+ " rews_np = np.array(rews) # [T, E]\n",
+ " achs_np = np.array(achs) # [T, E, num_ach]\n",
+ " dones_np = np.array(dones) # [T, E]\n",
+ "\n",
+ " # First-life cumulative reward + unlock count per agent\n",
+ " fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
+ " cum_reward = np.cumsum(rews_np, axis=0)\n",
+ " unlock_count = achs_np.sum(axis=-1)\n",
+ " for e in range(VIZ_NUM_ENVS):\n",
+ " end = np.where(dones_np[:, e])[0]\n",
+ " cutoff = int(end[0]) + 1 if len(end) > 0 else VIZ_STEPS\n",
+ " axes[0].plot(np.arange(cutoff), cum_reward[:cutoff, e], label=f\"agent {e}\")\n",
+ " axes[1].plot(np.arange(cutoff), unlock_count[:cutoff, e], label=f\"agent {e}\")\n",
+ " axes[0].set_title(\"Cumulative reward (first life)\")\n",
+ " axes[0].set_xlabel(\"env step\"); axes[0].set_ylabel(\"cumulative reward\")\n",
+ " axes[1].set_title(\"Achievements unlocked\")\n",
+ " axes[1].set_xlabel(\"env step\"); axes[1].set_ylabel(\"# unique achievements\")\n",
+ " for ax in axes:\n",
+ " ax.legend(loc=\"best\", fontsize=8); ax.grid(alpha=0.3)\n",
+ " plt.tight_layout(); plt.show()\n",
+ "\n",
+ " # Action histogram (which actions does the planner actually use?)\n",
+ " from craftax.craftax_classic.constants import Action as ClassicAction\n",
+ " action_names = [a.name for a in ClassicAction]\n",
+ " counts = np.bincount(acts_np.flatten(), minlength=NUM_ACTIONS)\n",
+ " order = np.argsort(-counts)\n",
+ " fig, ax = plt.subplots(figsize=(11, 3.5))\n",
+ " ax.bar(range(NUM_ACTIONS), counts[order])\n",
+ " ax.set_xticks(range(NUM_ACTIONS))\n",
+ " ax.set_xticklabels([action_names[i] for i in order], rotation=70, ha=\"right\", fontsize=8)\n",
+ " ax.set_title(f\"Action usage over {VIZ_STEPS * VIZ_NUM_ENVS} env steps\")\n",
+ " ax.set_ylabel(\"count\"); ax.grid(alpha=0.3, axis=\"y\")\n",
+ " plt.tight_layout(); plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "522c414d6a5a",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CELL 7 (PRIORITY) \u2014 visualise the ReMDM iterative unmasking process\n",
+ "# =============================================================================\n",
+ "#\n",
+ "# `sample_plan_inpainting` is JIT-compiled via lax.scan, so to capture the\n",
+ "# intermediate token sequences we re-implement its body as a Python loop. This\n",
+ "# is faithful to the production sampler \u2014 only the loop construct differs.\n",
+ "# -----------------------------------------------------------------------------\n",
+ "\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.colors import ListedColormap\n",
+ "\n",
+ "if \"Classic\" not in ENV_NAME:\n",
+ " print(f\"Skipping denoising viz: no diffusion checkpoint for {ENV_NAME!r}.\")\n",
+ "else:\n",
+ " DENOISE_STEPS = 12 # show enough steps to see iterative unmasking\n",
+ " MASK_ID = NUM_ACTIONS\n",
+ "\n",
+ " # Sample one observation from the env to condition on.\n",
+ " viz_env = make_craftax_env_from_name(DIFFUSION_ENV_NAME, auto_reset=True)\n",
+ " viz_params = viz_env.default_params\n",
+ " one_rng = jax.random.PRNGKey(SEED + 2)\n",
+ " obs1, _state1 = viz_env.reset(one_rng, viz_params)\n",
+ " obs_b = obs1[None, :] # [1, obs_dim]\n",
+ "\n",
+ " seq = jnp.full((1, PLAN_HORIZON), MASK_ID, dtype=jnp.int32)\n",
+ " history = jnp.full((1, PLAN_HORIZON), MASK_ID, dtype=jnp.int32)\n",
+ " hist_len = jnp.zeros((1,), dtype=jnp.int32)\n",
+ " rng = jax.random.PRNGKey(SEED + 3)\n",
+ " temperature = float(ARCH_CFG[\"TEMPERATURE\"])\n",
+ " top_p = float(ARCH_CFG[\"TOP_P\"])\n",
+ "\n",
+ " trace = [np.array(seq[0])] # row 0 = fully masked\n",
+ "\n",
+ " # Mirror the body of src.diffusion.sampling.sample_plan_inpainting._step\n",
+ " for step in range(1, DENOISE_STEPS + 1):\n",
+ " rng, model_rng, sample_rng, remask_rng = jax.random.split(rng, 4)\n",
+ " ratio = step / DENOISE_STEPS\n",
+ " t_tensor = jnp.full((1,), 1.0 - ratio)\n",
+ " logits = apply_eval(diffusion_params, obs_b, seq, t_tensor, model_rng) / max(temperature, 1e-8)\n",
+ "\n",
+ " # Nucleus filtering\n",
+ " probs = jax.nn.softmax(logits, axis=-1)\n",
+ " sorted_idx = jnp.argsort(-probs, axis=-1)\n",
+ " sorted_p = jnp.take_along_axis(probs, sorted_idx, axis=-1)\n",
+ " cutoff = jnp.cumsum(sorted_p, axis=-1) - sorted_p\n",
+ " inv_idx = jnp.argsort(sorted_idx, axis=-1)\n",
+ " nucleus_mask = jnp.take_along_axis(cutoff >= top_p, inv_idx, axis=-1)\n",
+ " logits = jnp.where(nucleus_mask, -jnp.inf, logits)\n",
+ "\n",
+ " preds = jax.random.categorical(sample_rng, logits, axis=-1)\n",
+ " conf = jnp.take_along_axis(\n",
+ " jax.nn.softmax(logits, axis=-1), preds[..., None], axis=-1,\n",
+ " ).squeeze(-1)\n",
+ " num_unmask = max(1, int(PLAN_HORIZON * ratio))\n",
+ " sorted_conf = jnp.sort(conf, axis=-1)[..., ::-1]\n",
+ " thresh = sorted_conf[0, num_unmask - 1]\n",
+ " seq_new = jnp.where(conf < thresh, MASK_ID, preds)\n",
+ "\n",
+ " # ReMDM-style remasking (matches sample_plan_inpainting)\n",
+ " remask_prob = 0.15 * (1.0 - ratio)\n",
+ " do_remask = (\n",
+ " (jax.random.uniform(remask_rng, seq_new.shape) < remask_prob)\n",
+ " & (seq_new != MASK_ID)\n",
+ " )\n",
+ " seq_new = jnp.where(do_remask, MASK_ID, seq_new)\n",
+ "\n",
+ " # Lock historical prefix (here: empty, so no-op)\n",
+ " pos = jnp.broadcast_to(jnp.arange(PLAN_HORIZON)[None, :], (1, PLAN_HORIZON))\n",
+ " seq_new = jnp.where(pos < hist_len[:, None], history, seq_new)\n",
+ "\n",
+ " seq = seq_new\n",
+ " trace.append(np.array(seq[0]))\n",
+ "\n",
+ " trace = np.stack(trace) # [steps+1, plan_horizon]\n",
+ "\n",
+ " # Heatmap: rows = denoising step, columns = action position\n",
+ " # Mask cells coloured grey, action cells coloured by token id (viridis).\n",
+ " masked = trace == MASK_ID\n",
+ " fig, ax = plt.subplots(figsize=(12, 5))\n",
+ " cmap = plt.cm.viridis.copy()\n",
+ " display = np.where(masked, np.nan, trace.astype(float))\n",
+ " im = ax.imshow(\n",
+ " display, aspect=\"auto\", cmap=cmap, vmin=0, vmax=NUM_ACTIONS - 1,\n",
+ " interpolation=\"nearest\",\n",
+ " )\n",
+ " # Overlay grey for masked cells\n",
+ " ax.imshow(\n",
+ " np.where(masked, 1.0, np.nan), aspect=\"auto\",\n",
+ " cmap=ListedColormap([\"#dddddd\"]), vmin=0, vmax=1, interpolation=\"nearest\",\n",
+ " )\n",
+ " ax.set_xlabel(\"plan position (action token)\")\n",
+ " ax.set_ylabel(\"denoising step\")\n",
+ " ax.set_title(\n",
+ " f\"ReMDM iterative unmasking ({DENOISE_STEPS} steps, plan_horizon={PLAN_HORIZON})\"\n",
+ " \"\\nGrey = MASK token, colour = sampled action ID\"\n",
+ " )\n",
+ " cbar = plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02)\n",
+ " cbar.set_label(\"action ID\")\n",
+ " plt.tight_layout(); plt.show()\n",
+ "\n",
+ " n_masked = masked.sum(axis=1)\n",
+ " print(f\"Masked tokens per step: {list(n_masked)}\")\n",
+ " print(f\" step 0 (start): {int(n_masked[0])}/{PLAN_HORIZON} masked\")\n",
+ " print(f\" step {DENOISE_STEPS} (final): {int(n_masked[-1])}/{PLAN_HORIZON} masked\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "8bc67b9b704a",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CELL 8 \u2014 PPO expert baseline (loaded live, evaluated on `ENV_NAME`)\n",
+ "# =============================================================================\n",
+ "#\n",
+ "# DAgger trains the diffusion planner to imitate this expert. We expect the\n",
+ "# DAgger return (Cell 5) to be close to the PPO expert return on Classic.\n",
+ "# For Full Craftax (where no diffusion checkpoint exists), this is the only\n",
+ "# evaluation that can run.\n",
+ "# -----------------------------------------------------------------------------\n",
+ "\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import orbax.checkpoint as ocp\n",
+ "from orbax.checkpoint import checkpoint_utils\n",
+ "from craftax.craftax_env import make_craftax_env_from_name\n",
+ "\n",
+ "from src.planners.ppo import build_ppo_network, PPOAgent\n",
+ "\n",
+ "PPO_NUM_ENVS = EVAL_NUM_ENVS\n",
+ "PPO_EVAL_STEPS = min(EVAL_STEPS, 1500) # PPO eval is fast \u2014 bound it for snappy demo\n",
+ "\n",
+ "ppo_path = os.path.abspath(PPO_CKPT[ENV_NAME])\n",
+ "ppo_env = make_craftax_env_from_name(ENV_NAME, auto_reset=True)\n",
+ "ppo_params_env = ppo_env.default_params\n",
+ "ppo_num_actions = int(ppo_env.action_space(ppo_params_env).n)\n",
+ "ppo_obs_dim = int(ppo_env.observation_space(ppo_params_env).shape[0])\n",
+ "\n",
+ "# Build the PPO-RNN network and an abstract param pytree from a dummy init.\n",
+ "PPO_LAYER_SIZE = 512\n",
+ "ppo_net = build_ppo_network(\"ppo_rnn\", ppo_num_actions, PPO_LAYER_SIZE,\n",
+ " {\"LAYER_SIZE\": PPO_LAYER_SIZE})\n",
+ "_dummy_x = (jnp.zeros((1, PPO_NUM_ENVS, ppo_obs_dim)),\n",
+ " jnp.zeros((1, PPO_NUM_ENVS)))\n",
+ "_abstract_params = ppo_net.init(\n",
+ " jax.random.PRNGKey(0),\n",
+ " jnp.zeros((PPO_NUM_ENVS, PPO_LAYER_SIZE)),\n",
+ " _dummy_x,\n",
+ ")\n",
+ "\n",
+ "# Restore params only (the on-disk checkpoint also contains opt_state which we\n",
+ "# don't need for inference). `partial_restore=True` lets us read just `params`,\n",
+ "# and `construct_restore_args` is required on orbax >= 0.11 to provide sharding.\n",
+ "_restore_args = checkpoint_utils.construct_restore_args({\"params\": _abstract_params})\n",
+ "with ocp.CheckpointManager(ppo_path) as _mgr:\n",
+ " _step = _mgr.latest_step()\n",
+ " _restored = _mgr.restore(\n",
+ " _step,\n",
+ " args=ocp.args.PyTreeRestore(\n",
+ " item={\"params\": _abstract_params},\n",
+ " restore_args=_restore_args,\n",
+ " partial_restore=True,\n",
+ " ),\n",
+ " )\n",
+ "print(f\"Loaded PPO_RNN checkpoint from '{ppo_path}' (step {_step})\")\n",
+ "\n",
+ "ppo_agent = PPOAgent(\n",
+ " network=ppo_net,\n",
+ " params=_restored[\"params\"],\n",
+ " model_type=\"ppo_rnn\",\n",
+ " layer_size=PPO_LAYER_SIZE,\n",
+ ")\n",
+ "\n",
+ "rng = jax.random.PRNGKey(SEED + 100)\n",
+ "rng, env_rng = jax.random.split(rng)\n",
+ "obs, state = jax.vmap(ppo_env.reset, in_axes=(0, None))(\n",
+ " jax.random.split(env_rng, PPO_NUM_ENVS), ppo_params_env,\n",
+ ")\n",
+ "hidden0 = ppo_agent.init_hidden(PPO_NUM_ENVS)\n",
+ "done0 = jnp.zeros(PPO_NUM_ENVS, dtype=bool)\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "def ppo_step(carry, _):\n",
+ " obs, state, hidden, done, rng = carry\n",
+ " rng, act_rng, env_rng = jax.random.split(rng, 3)\n",
+ " action, hidden = ppo_agent.act(obs, done, hidden, act_rng, temperature=1.0)\n",
+ " obs_next, state_next, reward, done_next, _ = jax.vmap(\n",
+ " ppo_env.step, in_axes=(0, 0, 0, None),\n",
+ " )(jax.random.split(env_rng, PPO_NUM_ENVS), state, action, ppo_params_env)\n",
+ " return (obs_next, state_next, hidden, done_next, rng), (reward, done_next, state_next.achievements)\n",
+ "\n",
+ "\n",
+ "print(f\"Running PPO expert: {PPO_NUM_ENVS} envs x {PPO_EVAL_STEPS} steps on {ENV_NAME}...\")\n",
+ "_, (rewards, dones, achievements) = jax.lax.scan(\n",
+ " ppo_step, (obs, state, hidden0, done0, rng), jnp.arange(PPO_EVAL_STEPS),\n",
+ ")\n",
+ "\n",
+ "rewards_np = np.array(rewards)\n",
+ "dones_np = np.array(dones)\n",
+ "ach_np = np.array(achievements)\n",
+ "\n",
+ "# First-life evaluation (matches src/planners/inference.py convention)\n",
+ "ep_returns = np.zeros(PPO_NUM_ENVS)\n",
+ "ep_unlocks = np.zeros(PPO_NUM_ENVS, dtype=int)\n",
+ "for i in range(PPO_NUM_ENVS):\n",
+ " deaths = np.where(dones_np[:, i])[0]\n",
+ " end = int(deaths[0]) if len(deaths) > 0 else PPO_EVAL_STEPS - 1\n",
+ " ep_returns[i] = rewards_np[: end + 1, i].sum()\n",
+ " ep_unlocks[i] = int(ach_np[: end + 1, i].max(axis=0).sum())\n",
+ "\n",
+ "print()\n",
+ "print(f\"PPO expert mean return : {ep_returns.mean():.2f} (best={ep_returns.max():.2f})\")\n",
+ "print(f\"PPO expert mean unlocks: {ep_unlocks.mean():.2f} achievements\")\n",
+ "if \"Classic\" in ENV_NAME:\n",
+ " print()\n",
+ " print(\n",
+ " \"Compare with the DAgger diffusion result printed in Cell 5. On the\\n\"\n",
+ " \"DAgger checkpoint with EVAL_STEPS=10000 / 32 envs, both methods score\\n\"\n",
+ " \"\u2248 10.4 / 22 in the paper \u2014 i.e. DAgger has fully imitated the expert.\"\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "059f96f8e951",
+ "metadata": {},
+ "source": [
+ "## 2. RL fine-tuning ablation suite (pre-computed)\n",
+ "\n",
+ "Once DAgger had reached the expert ceiling, we ran a **25-ablation suite** of\n",
+ "RL fine-tuning interventions trying to push past it. Every figure and table\n",
+ "below was generated offline by `experiments/rl_finetuning/run_ablations.py`\n",
+ "and shipped in the HF repo at\n",
+ "`experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/`.\n",
+ "\n",
+ "The four groups tested in the ablation suite:\n",
+ "\n",
+ "| Group | Hypothesis tested |\n",
+ "|---|---|\n",
+ "| **A** Regularisation | KL penalty, EWC, LLRD, LoRA, mixed replay, hard trust region |\n",
+ "| **B** Optimisation | t-curriculum, entropy bonus, PCGrad, advantage clip, normalised adv, BC-on-wins, low-t |\n",
+ "| **C** Capacity | head-only / FFN-only / attention-only / frozen backbone / top-k layer ablations |\n",
+ "| **D** Data | reward filtering, action diversity, running stats, learned reward model |\n",
+ "\n",
+ "**Headline finding.** No ablation meaningfully improves on the DAgger\n",
+ "checkpoint (best \u0394-score over baseline RL is +0.18 \u2014 within noise). Multiple\n",
+ "ablations *collapse* (\u0394 \u2248 \u221210) \u2014 see the figures below.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "522024988f95",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CELL 10 \u2014 Pre-computed ablation figures\n",
+ "# =============================================================================\n",
+ "\n",
+ "import os\n",
+ "from IPython.display import Image, display, Markdown\n",
+ "\n",
+ "ABLATION_DIR = \"experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures\"\n",
+ "\n",
+ "KEY_FIGURES = [\n",
+ " (\n",
+ " \"group_comparison.png\",\n",
+ " \"**Group comparison** \u2014 final score by ablation group. Group A \"\n",
+ " \"(regularisation) clusters tightly around the baseline; Group C \"\n",
+ " \"(capacity ablation) collapses.\",\n",
+ " ),\n",
+ " (\n",
+ " \"score_delta_over_baseline_rl.png\",\n",
+ " \"**\u0394-score over baseline RL** \u2014 every ablation, sorted. The best \"\n",
+ " \"improvement is < +0.2; many collapse to \u221210.\",\n",
+ " ),\n",
+ " (\n",
+ " \"gradient_alignment.png\",\n",
+ " \"**Gradient alignment** \u2014 cosine similarity between the BC and RL \"\n",
+ " \"loss gradients during fine-tuning. Frequently negative, hinting at \"\n",
+ " \"the underlying conflict.\",\n",
+ " ),\n",
+ " (\n",
+ " \"gradient_conflict_map.png\",\n",
+ " \"**Per-layer gradient conflict** \u2014 where in the network the BC and \"\n",
+ " \"RL losses pull in opposite directions.\",\n",
+ " ),\n",
+ "]\n",
+ "\n",
+ "for fname, caption in KEY_FIGURES:\n",
+ " path = os.path.join(ABLATION_DIR, fname)\n",
+ " if os.path.exists(path):\n",
+ " display(Markdown(caption))\n",
+ " display(Image(filename=path))\n",
+ " else:\n",
+ " display(Markdown(f\"_Missing figure: `{fname}`_\"))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "a2a77a2dd8b1",
+ "metadata": {},
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# =============================================================================\n",
+ "# CELL 11 \u2014 Ablation results tables\n",
+ "# =============================================================================\n",
+ "\n",
+ "import os\n",
+ "import polars as pl\n",
+ "\n",
+ "TABLES_DIR = \"experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables\"\n",
+ "\n",
+ "main_df = pl.read_csv(os.path.join(TABLES_DIR, \"main_results.csv\")).sort(\n",
+ " \"Final_Score\", descending=True,\n",
+ ")\n",
+ "print(\"Main ablation results (sorted by Final_Score, ceiling = 22):\")\n",
+ "print(main_df)\n",
+ "\n",
+ "verdict_df = pl.read_csv(os.path.join(TABLES_DIR, \"hypothesis_verdict.csv\"))\n",
+ "print()\n",
+ "print(\"Hypothesis verdicts:\")\n",
+ "print(verdict_df.select([\"Ablation\", \"Group\", \"Result\", \"Conclusion\"]))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d2dbe7e3e2ab",
+ "metadata": {},
+ "source": [
+ "## 3. Conclusions\n",
+ "\n",
+ "1. **DAgger works.** The online DAgger checkpoint imitates the PPO expert on\n",
+ " Craftax-Classic to within noise (Cells 5 vs 8).\n",
+ "2. **RL fine-tuning does not help.** Across 25 ablations spanning\n",
+ " regularisation, optimisation, capacity and data interventions, no method\n",
+ " meaningfully improves on the DAgger checkpoint. The best \u0394-score over\n",
+ " baseline RL is +0.18, well within seed-to-seed variance.\n",
+ "3. **Several ablations actively collapse** (Group C: capacity restrictions,\n",
+ " plus a handful of regularisers). The collapse pattern correlates with deep\n",
+ " gradient flow into the backbone \u2014 see `gradient_conflict_map.png`.\n",
+ "4. **Framework independence.** The same finding holds in our parallel MiniHack\n",
+ " PyTorch codebase, suggesting the obstruction is fundamental to RL\n",
+ " fine-tuning of masked discrete diffusion planners rather than a Craftax-\n",
+ " or JAX-specific quirk.\n",
+ "\n",
+ "The full project documentation lives in the HF repo's `README.md`.\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "pygments_lexer": "ipython3"
+ },
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png
new file mode 100644
index 0000000000000000000000000000000000000000..ae25bc6f337351b600ecda1e1a5b14ae84136294
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8fd0272fca80a41f963407732ae253f29dfc5854ff9b5b523c211b8a8ac66331
+size 237515
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..d85fc59485d8d3203b39382031135bfef0107aef
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6eb2e3b57cef0f52ca8fa9d11dc5e519ecde4384b90d905b4718e1d05eadd45
+size 135933
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..5eecfdbbc611b0e0a968c706cbf343962d47940f
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0553ad5e2ed113a0600cf73553275fdf01638c2dd0b6cd665c57a1f2581e601c
+size 136022
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef13baa8ae189aa3695de18d3fc24f53b4237ba9
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee83c2fe557294320c78ee8bd8d92933122f995dfce1f46d34f498475ace334a
+size 133347
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..38d08b7f0af5b65d493be0b07e04f2ae097a307f
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:010af9f97e110e07ed29784d80969f2597a9f6081db620f98184417c747fcd85
+size 134511
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..54ce17a38510fe2603fcc78dfc2be1efb6b0e832
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cd85f3e810b8cdec773f371386c4397fbe634b44bb63f8ad16443cad14b3f0d
+size 134632
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..d188974a2c9c3f0e2ec14514d6bf0157ab1d4bbc
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba4b34ba6d7ccb2c2583e6f35ead7e466e62e4bffe9e91dc8867111fc67dfc9e
+size 135210
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..1dea7616447eb3e181068bcc926ae0f5be56b5e7
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:045be833a526d9877e900bc3de3485d00fc61584be64714c4b622a068c2ad98a
+size 133720
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..0e8c71e560be90d75b5ef267da6079e2bde0c26c
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b46b767bf0d0bd138521dc1ee3ee06dfacab3ed2598249cecd02b400fc794a2
+size 134017
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..f4b9e1aa61e878b91f6b8549ec38d08cb136bf20
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0037a7256ba08b31ca81611a36e03f7db615f407268a5c9bda66206d0790fc4c
+size 133997
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..f01365700450f3bcc1506d5eb398be8ba836840e
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95226a3364a619c349d37befd6d610bac5cfd6c98a921b3c6a18c254f4062c62
+size 136173
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..6e4d81bae783d01720d6c7b612215b4f878db6ac
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21166bb529dcc724b805f446c624b46cb70791e3431e4df3a04dd34ff465e93f
+size 132643
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..177ab427c9e63a8f1596ba9607134a40cb15f4af
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4217cf6a98602c5ba3fb7f9f468d59a931f3943f861796569cfa8561e0e59f0e
+size 134951
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..14e1eea3b3a6488ac6f243f517b45012094583c3
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65bf5071ea5c277181cfc6f9c6494ceda57ed254a1f5745ccded973efe560a40
+size 136013
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..af5d3ecb91d23766c51ab17c067f10cdda253f66
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:228cbd1d905792fbbd2f44f9c8d02a7f055a3f8e87fee5b98663e9e2a4c7064c
+size 136454
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc7c5c586be53a47dca999c646cee7419a292bcb
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34e2cf8ba9dcddb86bc358846d268a4ae32918c67afb9b377f5b8dd04a6a97c9
+size 136433
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..f599187d2fe3a4cdf58744abdf4c884b17cce5be
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:607afe4b1edd9232abb71f5ac2b13ccf755dfda876fa079bcf38a416e22a700d
+size 133481
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..1942fc5c35db229d2ff22780b0abca18048ac957
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f393911fcbebdf0975cf856f5f89bfe29a910aa8810b5b0b05dee6bd36b61402
+size 131535
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9568e099c123d55278f1e5314fe2ef24b7f1061
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b12fecc1bb6a52be0eead06d376d03181a2d28d0948bf4e0e4d13d4b2e9b29f6
+size 134018
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..f6235b007ada5c9f9069eea0e00181b2755fbf8f
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ea22bc4ba454bf3f4afc12bd2ae2ff24a1f6c282cbdfe7ff101e2709cd04d5a
+size 135081
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..295022834e613ef957c6dc99bf5e21e4728b475d
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23ab90a189e70f55f1326cc7063994fbe79a32bee0be5b5639cf53df529848eb
+size 135355
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..f058eb12d16734126f207a8567d1b1770e6a0ac6
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40c1b9b3e09a33b9ed9b263c4c41618dca9eebe30ec932da9e4ec7b1ecca130e
+size 135850
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..0127bb0678d56c85c721d01808c79bbe6accd0dc
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95a05f55fca14efed58d57a100adb007668040ad953763d13907116759e40ee5
+size 135375
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5a9b983bb43a723e6e69aae5a08e63ed3095636
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08fc0df1830dfe974d5bbc8f57e059a3cc4a84098f753485a187ffa1ff35a88b
+size 135119
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..d209d50a2d598096f8c81f6c62a960f0c2f658a8
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8bb7bf6adca273d06d8588f1100ba8f84d3e5a2248593491a9ee1a301458495
+size 134524
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..0954e1490afcf90d22e2055057e8a84dcba8e03b
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:626f6de40a928e34b3ad2f0cd08ae5b77ba21b1197407e585a0f5c324a245930
+size 134955
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee708ca6476ff21de43ddab4abc77785f0310777
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_action_diversity.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..129ff30f5cf3b46db5cca159ce52d84f8e9800e5
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_advantage_clip.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..36068856843e3857eb4975dfd17fd06a5dfe1a6c
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_attention_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..19a40d01dbf279ed94893f1e626a49fd460c9ca7
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_baseline_rl.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..88de1b2c885f882babf55bd87397f8e78cf45c89
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_bc_wins.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd4b2b7a1123b9287bf7238329980bd2775c6f7f
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_entropy_bonus.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..7efa39242bfc08a395e1831c00568f382d0ca3fe
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ewc.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..acba80473c0b6d4881c1391a75f2f550c2d34c9e
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ffn_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..c51249a3d2c7c916850fbe056e6b365c38a73387
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_frozen_backbone.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..1b4cb141d1636ff4fc02f9138546950d8c6332a2
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_gradient_surgery.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d2adaa28336334cba06b8f4365c14d28cea529a
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_head_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc4769ee998d9edac22f40227cc61457ab3cbbfb
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_kl_penalty.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..d22be76bdd684170decbd1f3dafe3c13502f1681
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top1.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..181da76c1e9ecb4c9e0c03f9541e4d1a88525541
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top2.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..110b4a7725dd336559050d6f4cc87fa5bb637181
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top3.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..9caabed2ea78dd42e31f3df05004355f487753c7
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_llrd.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f999818a0cb9be3c9474e49f0f318ff815fe53d
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_lora.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..c016e18c2e344eb2bf7776cd751ca529510d7639
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_low_t.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..71057088f8d96954a4fc812f97e02d567a797541
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_mixed_replay.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..ea5e0361f672dac5663ea03df1e48edddb6b0c2f
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_normalized_adv.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..1809c8b70f0fd8e26d4a1250cf317e0d2d81a2cf
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_filtering.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..bdd0ee3503ebd9e6c59395c419f5e84d0c339837
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_model.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..b4fbcbeb8ca00342f2a3ab6f246ece8551da36ad
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_running_stats.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c57d92d70c79cf53e83e9629440de601bd75f49
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_t_curriculum.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..ff0393dae46f4cc7e29597d51a45eaf533ea4cbe
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_trust_region_kl.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d8696b9e44f1db2b977a26cbb5bfe38ea5c81e3
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_action_diversity.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..39d10ea2e59878ee247530135185ca6de6220c4a
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_advantage_clip.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..b57a609a5459e7a886505f2b654dde658ebc13ae
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_attention_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca35ed0b03d7884c41359cd4a8d9906c99bde7f1
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_baseline_rl.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..c807f2ec06c7711025cbfb2288d4f48a3b61b70f
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_bc_wins.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef44ae7d1ef7b1417bd8311e97b7ccd03a649609
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_entropy_bonus.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c0b581e8e7db09c02f307e8d8cabd4c7d776dc1
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ewc.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..287e7381fe8b29a2e3b635a7702ff4cf7f686b30
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ffn_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b277ce28e77355f0de5d4b5bd2ed9161eb526bf
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_frozen_backbone.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..e450b59c96a6e4365aee4e92f265b7fd23e5d87e
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_gradient_surgery.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..24fd8a10dd306101f58efce8dee764fd3d1c56f5
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_head_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b485a1452785663af7b00f62b525288dd6cf15c
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_kl_penalty.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a16c5c3f2eb18257562530e115b58986cbcffd4
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top1.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..2df327df133d32935fd907a7d02b55e337524834
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top2.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a927f431403e939da2e30a8c8c800db6d76ab118
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top3.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..55d9b11216a8679683633aa3dd313931f2bd7aa7
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_llrd.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..1174438b7bef57230fd9b7cf57f8275c2f6b6d61
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_lora.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..349e754da6aedbde1fc2e53171c87ff88482005c
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_low_t.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..74745cf99ad778ec4cb70f4bc570cc563001e1dd
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_mixed_replay.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..530add18e2e869cb2ccf971fc551eb5882076dea
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_normalized_adv.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..21761b876fe67bdaa19b75bf1564270d0050b0db
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_filtering.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c66a14c73934017ab40ffd470a6fd227fe50e95
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_model.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad874a6efce1e124fcfd0b9d0114d6c2c0f263be
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_running_stats.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..be37abd7469a25d2c2f5fe479453f65c34f77801
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_t_curriculum.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..22e33a613e0cc9395e5ecee7d5e790a803cd543f
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_trust_region_kl.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..34184d1b5fb7bbc48f3d74a4dc2ef03ee2870aa4
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51c1fddc8f4a7784df2e35debaa257c9d55ac0a90c4a6c5d71f64ff7651b863f
+size 132243
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..6bd7baf5d304e9c1a5b0807bebcee5cb5cb1e6be
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_action_diversity.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..8bf58cfd3883ba9b7b97de40ea20199a2cac2e33
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_advantage_clip.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..40c4120101996395b502e3e1cf096c083e1729d0
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_attention_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..54781c1097b8436e9eff006bac97f9ca4644a2f8
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_baseline_rl.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..bf6f528d3d3a02453f4095fcbebbde3086e63cb5
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_bc_wins.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..e21cc53573af9aceb3c7a96beeeb88f005108f59
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_entropy_bonus.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..ba72945ed3281fbb9d7829fd506de35c62be6209
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ewc.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..98d546858d2bfe866846c1844ed0353a41ddd3a9
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ffn_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..edde7351c85dc57c9b8b52a263a0223000bf97f6
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_frozen_backbone.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1ef9a286401d66076b0194bd13e93fa2c0941
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_gradient_surgery.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab70b7e7298233ba645e7e52486539cad8b61731
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_head_only.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..7dbbbf441d9c11ebbbf93a7e11431f20051e53d0
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_kl_penalty.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce0f8bd73a3f4f5072dbcfb74c01db3706be91c9
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top1.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..c4c599671379907d4b2e45ff605ca84a7310f826
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top2.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a2c3cb311f2903aff489c407b0f50ca8f286b845
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top3.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..57b26ca326283b92e16cf92badfbc963eb014e46
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_llrd.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d2d95a97d6ddb73c44d4349a7bc75e38aeea0e3
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_lora.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..339b001947a9f72934b56a76802764f991c00600
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_low_t.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..a24c85f65bb593824ff3055ca3af1920559c72e4
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_mixed_replay.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..f70f3dfd0fca60b6ee3fc3a149952d94a04cc439
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_normalized_adv.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b9aa5ef1b944a8fce413073533b661943015025
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_filtering.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..45938e25c038d8e9b602348a48e250555de0720c
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_model.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..4af89311f43c6ed5cdf4407e014609fcb5f5dece
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_running_stats.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..33a5a938dc96392309c40ee7b99bab4534675373
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_t_curriculum.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5ef83d76ef24422728311a0fca94e447d24bf9d
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_trust_region_kl.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png
new file mode 100644
index 0000000000000000000000000000000000000000..c310bc6a1f385c626d8ce2be6fc205a0b5d39720
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44593d43a80c26b41e478d39e5c8236623d7dc299ac5c5e283ce40b2b812c5c8
+size 290679
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..1a1c5a6a594192dfb3a0aa60e3d7fa6053a6c16d
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:692f105d51077ce63ff6ea982247e27efdb98ce427da2228f19b8f0b577ffec7
+size 355986
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..b5229ea5c5aed3d86ae71cad29a24caa247fe302
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41ddfc2337c144d5dec6162cbf570ecfb817c54793fbadf746ef0291f644cc80
+size 343397
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..8ec172260b948d454e965867c0efeabedab915a6
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf083bb7799ea94d910229a66c05088bdfb7ef77d7dd1e83c5cdb9afebc597e
+size 262074
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f29c6913b605cfe0a1983382a437bcf9f9476a8
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68a6a41afcb4bf2c0ae0eb887359949bba8c68437fd46944c580864e8e28557f
+size 331413
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1425c9617b77c55160a0efc3a31cd9a38f13d28
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c0061a7fcc6c7fb1cda0200c5c24395683010bc46bcb9dae8ac53a29d8a9549
+size 341009
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..905e694ab7c2369e0400d6ae1240bbf630535108
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5728b730211256bf431ff6a72fcf75379b1903d67b73e34644ef291a4de0c58
+size 330705
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..32bde7e5357f6366bea382a99b90c1419cd2490c
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cefbf2bce56057bf07b2fb5d6cf206f317b352fc0fc1db9d157c208799c11046
+size 321337
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..89d9d84b2885991a2556055894330fa4f3d821da
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46eeab5469caa72bf4ba8ca960fc6c9ca9da98d1423a554bfe869e25521980f2
+size 305749
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..63527eed5b03bb9aa7a52ae7a97e917f537a57b4
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9760f15a8db02091cbaad275466a743788a3adbf61eaf0ec417924e9c456ec41
+size 257882
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..b070484d580a3fb8b253fb73be50ca35a47aa5ff
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44bf4ae5051aef0ab769ea7b5a072be9cb28b47303351514dd9015c8392a966c
+size 345006
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..aebe0a3c7acddf4bafd93bd63698619644f4fa03
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84eefc42a64d0c8a513b2db52f0bd5f662d2ae4f77cae5a0d29cc7bca7013f86
+size 257202
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..c39046c1e5d4398b5ba125ecee85c56b17aee799
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a034184498328b065391d527931e23cab380ea96ad1fe011fd9be5a67c0c98c
+size 333918
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e68662069c4c25c534b53e04fa0d99a4ccee073
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99269fe8832261c43e0c4271ebf670f4b48f0aff72b76f9c6496bc45b3f14f16
+size 286476
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..192ca2a6df38b875f6939b94ec42fad247799e42
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a18329b9f876bc06de9193263269bef3a2eeed8cea78c171cf2b66a505145f0c
+size 291612
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6a516ce85db660480522aefa6420a28e42b3150
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:706c7c8edc372b72ea63584d2f7d96aa82d2d996fe294baccb037af9c1fc154c
+size 306494
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..0678dbe5a6b3e5ee6cd734d43da0761fa8085211
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c6ba8854e904f9d5662453c2ba45d1ee3efc9f4b3e9b83b1c34ba78f6c06c80
+size 340141
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..84afff3a206a6803eef3497c2c7a599f8c2420df
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:369bd3153e2e3cfe5c65c20d2b7ff4c72dcaa625d653cb53d33b6ee0e6b4a7fc
+size 236798
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3b5a967b24d0b0e0fb5adf1a1c8943b44b7fefb
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df3fe897122d26fe2595a737d3b3107142ee29a5a8483cb7c93e4b26f834bb6d
+size 327422
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..c47a2172737ec58b169525605d0bbe87de150ddd
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9d2653bb727995f48273b9a349ad5a05503328ba4ca0ddbaf64a436e57b859d
+size 344053
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..f45ddba21ee52a51e240d2f8e9948b15ff28684c
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08f0d4bb3e5caa39badf401cafc58cd72589a0087bd6965a852b6e92e20bb0f5
+size 252437
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..74622513ef82da68cf053695ebb5fc6c2958bcc6
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b25a6398d0202ea5f1ed44231aef0682651e975a4e4fdc5eaf5c7e81d19eb502
+size 327555
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..94abe0730dd5cbdef3ebf19807c014ca22954067
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bdfe4260d665dd71af6abdeae55f7f7992e5ac4ec98be8ea2c3f425a3e27b11
+size 337185
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d87cb3cedf53c0e09b9b56f76423e98b3a866da
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:132fa87c7abd362e297b623817465c36ab85c48ae5e965fd58958c78c586c5f7
+size 329614
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..9df625eaf9d00624d14874833d0c4d0e17e11c19
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf9f3044a96384420ada27a04b487f6d7c3b58be37f3227f8e38495aac3fb320
+size 340798
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f533edfee5119bb4ac87fb8d4abe04b820f5fc1
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d993bc2e7a32b696c365cc2f8eda086809a863af0ee97b5b74f657658787cca3
+size 334205
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png
new file mode 100644
index 0000000000000000000000000000000000000000..50d1f59f3260eeaae9f65cc2690e4a498469bd01
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:919183e2b4d733aaeedade61df932f0b7b3d7ee808d3af20b6b3f81320ac6fcb
+size 104578
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png
new file mode 100644
index 0000000000000000000000000000000000000000..e0eb1e13fec7b25e939ba8f8daf20fdbe0f3848d
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0122400add272c44b2aa9b3659edf931d63192ccb1cb9c338171b5fba388ac3
+size 347896
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..6e9a675069881584e8098abed67eaa51068dbd8b
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec17c47b7cb8a29e846b7d403c0894686f6a1d3c4ebd308c69df453b9bd21eab
+size 150907
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png
new file mode 100644
index 0000000000000000000000000000000000000000..8e3f489294cffc5704593f29d745ce8905806ccc
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfee54a47ebe64124c994e6a10db054b9e6d61a2c9f726c454470b3971c3dea6
+size 248912
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png
new file mode 100644
index 0000000000000000000000000000000000000000..0896e519f1badfa414f4faa2556ce8e51b08cd84
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbdf5c36f4710f63813faedfde96e5ce001685d3dddcf1ab11b8fd0be1abce5d
+size 146081
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/group_comparison.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/group_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..7423487b00d8b996f7da2a901a58250c7be58b2d
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/group_comparison.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e1f3907020d6f8f67a89127449ffcf5e6512ad1
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9117294129e04b011b00ebdab11715c4afff988968c24d29df87389e8ed124b3
+size 635770
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..63bb9cee32becbf2c376c38fdad85931d306ef13
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb3e78245c5051ff9596ca5a8ee020a82cf21c75df8fd42df4f90f0bd387d443
+size 648859
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..e1c982a658c8106fb0adb338e8dbd394f18c8aa0
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:965213853ede94e3ffd82c132fcf28ea485d55e956cec0560e378b2bcb8943fb
+size 523522
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..82ddfc80320b77265fa688ad37c57eb1b36f7a35
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e250b22110869a3f2437ce0da897419ce06eeeb4e3fac43137822562dc8f7532
+size 633145
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..2f6107cefbae7c8d4e513090ea5d62f960752db8
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8fe0cc2ba771668987f1c894f0b6cb97f5a80d6085bfcc1bf4445f4cda296e2a
+size 657126
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..bba2a91b4082a3df508c692d262f73323db272dc
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b90855ca01ab188a82497a2300f645a27b8afce196e6e9cde8a231af31c8c98e
+size 650095
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..f4a03ad1fe9e41c93a8c4808028f245e28a37fe3
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28972e9618898cf93f2faf9056c0e59007ddd024a3c9bc7b46f303f064884b97
+size 620015
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..ffe05afa37d06dbe41b1cbe23754c2378688ca99
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f71e241620c23f0d0c11d31a688600a9887d6345320e46f8011e1070302f2d90
+size 579344
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..16cf861138b847d7f61d8b21e54f723ac5ca9d0a
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90adc39ac2eab073cb04d3901a3f1e82d6d4ab0ea1a28b4cc2ff2af19287b0bd
+size 470394
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..0eb919762fafc34fec20eedb7678942df7fbf7e1
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:547c6bca9c6059378318db5a3772a9a59b79a2235da093acfa3dc01bba5f10ac
+size 634449
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..6e70c4ffc50dd44c25a3bb3ef60afcf46440d6c4
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a135f63a1e1e6c6fa527802fdb7dc072dda40cb1c681464dfc6bc02259d7820
+size 471173
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b1b5d98f8dbd1dda63ffe9ab2c894e4f7e05468
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e1788e4720fbeac598ba06d70e5bc6239cd5e206350dcac1805d4faf8180675
+size 648697
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..6090145bf08ee37e0afaf15bec06ea382db210aa
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85a743053213740f4288e2cf715165d5e1c36d30c35ff540bc3a47c1ac1f7c6a
+size 571169
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..76b011b88b9d376792c7adc3f195b3833c8d2cca
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9833dc299acbff783db8ba4ccdd6328c9e5cad6c236b9d53e68933244fff619
+size 595591
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..5d3697e719cc2d3c0aaa79b52a97aa6bbf7656a3
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab72de8a763a49b97e3ef0c3f5ac87e4c37559275f6fd716bc125fc37786913e
+size 634144
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e36b27446b60f78ec4ea77212c85f875888b481
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5adfbb0c7e74d4609063ef1f842345eae97d59f435b4cb37dc179fdd9a4701d
+size 632657
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..925c840f0e8f6389aa8d5e6c50465166375007b2
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37b2c0d771f4f849a647a4083f774de47480a388d8c830bfb34db521f2049c1e
+size 562059
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..57a1cee5da24e9344af5fc142f4c1409ce42290a
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efb3cd405fd950a3f19991b173178082a5cb75c0af0c32c597af3b92df8b821f
+size 655970
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..04a0bd9dc44547098146d66d7c2007518b9a0d94
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:869638d9115e97dfd808143d7dd406ee4ed43adb002f2e0a56fd950f9c5fde46
+size 619249
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..289e87f22f6d071a9723592b7c9ec9e5ee289639
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee3e4be59cfe42e837d62fa50d507f919afeefd3cba62a4f2650706d1927dbd5
+size 604201
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c2321c56378015ec65b8f58bdb89f9370a56cbd
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de6e1001eae7d7664c90d9f429f0fe292b78186e8415f677f545f4b3ca544db1
+size 647356
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..fa13e6cbcdcf0707cca32f8a0ff79c40d880aa02
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9daa06890a2dab01bfe729b493b5bbee6ef393ca02138cfe6acc883039c619ce
+size 630140
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c8b46e1fcd2762a88e83d97b2bf9ee1014a9ebb
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcb8698ed91abb5c11c2ebac01fa8f02ffd371e400ee7b020bcbf551e406ba9f
+size 663296
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..66a29584c966ef833178c0a927f99685a871eee7
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8239835c15562278c31f7adce44a595cdbd73493b9892ff191d36d41228cbe4b
+size 664034
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..47fc6a6d10f2b24399e28c63c988eb62f25da9f0
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8723ec21354878e4104db606715319b8f2f086f75920ee56a8f37fb7b22b79be
+size 608690
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb3f1feb52b2c2c4840aa8aaf1d6c6c0fb28cd79
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6547c1be5f274793b131c9df4ffab650237e78591f480dc53b79f87bb50e6370
+size 209309
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b2d2653aa18516117c2967b49387c12d790a05f
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a71e3d5e3a6a7d24fe06e763676821a23716d070bf5f6ad791d98f2e8eafb26
+size 134007
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png
new file mode 100644
index 0000000000000000000000000000000000000000..d0f5ae6df6ceaac228fd35a3ef30c10bf47043cf
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:007ace3dad75964a45c593569f5f8f4ee58f49093358f701c5075f2828cb90ec
+size 187940
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png
new file mode 100644
index 0000000000000000000000000000000000000000..5727e6c1575d9f327faca7af66a545b70fd41c48
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f055d32819cd7592d2c168bc93c121fb0b4f262bfe303d2902570cb2db905f0a
+size 192830
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..9847688ad3e88ae123cdc4df9f2bd6742c468ade
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53f167b6fb25034bc2e6ee9848e7a0bef6a22c24d9efa19d3f735c210b9fa24e
+size 195455
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png
new file mode 100644
index 0000000000000000000000000000000000000000..fa467a45ded4481c919bc3b3817590ea69e623b2
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:532e3f8761399ea366804661f148b65e1f415c272d909c5365bc2694e5670275
+size 186133
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ed4fad94c30b3f74b66408975ae423a1bf1e1e6
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0616e3abc75b57a5bd75d6c403849f4a8f16fc24e6cea08d85a499496a883bf5
+size 153627
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png
new file mode 100644
index 0000000000000000000000000000000000000000..683df4c4b281171b0fc77e79528663cd5944eba0
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:803f74d13538b199b36a734774543922fef2ddf88699474e98a093c397e60118
+size 183155
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png
new file mode 100644
index 0000000000000000000000000000000000000000..d1a2f3cd3b1f4ad5467f7bbb25618e2f9578006e
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05ca0467d3717181c746f3e2ed1aa3f71bdf347d7162d44ddaacce4527d8bb92
+size 175387
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..9f430954dae038ff0f5ad79d4563fcecd6161869
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13e519ebf3e30e8e9b55db8d30280af3a5a4e783d11c438825941c133b189f28
+size 215131
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png
new file mode 100644
index 0000000000000000000000000000000000000000..edd37da56856ab400728b58655b653f916b0b9c3
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:785fd8bf211f93e1f07289b60d29d5f4fd253063eda5338bfb673da3823807eb
+size 187307
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png
new file mode 100644
index 0000000000000000000000000000000000000000..ff21f64911b019a38c5c0f8dcfa87363a3787ef7
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7fc39fddd3b6cd3e39768f7241f1ebc22252dd17d96f28e5aeb87f37c1dd68f6
+size 187654
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ebc5fe6706a3818a0b2b06451d58e8ed9f14bda
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f988502a78935585b3b04c9604b65ee0727ccf7f9bf3d4819e8e2612972d16dd
+size 187162
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png
new file mode 100644
index 0000000000000000000000000000000000000000..23c979b88cbd4fe6ae80c5fd3e7f41554fdabb52
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8dc72f23b90d98ce82a4411dba2c4affa3c0f06046d127c3e3098dbccb60ca69
+size 184990
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f258e7d84f162effbc1738fc062f68036b2d534
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ead65ae8815f6025348d58f8f70f94ebca866cad91fd11a3be0f24646037d83e
+size 218911
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1b3d3c1643d969bd97bbb51a16e2966c17470fc
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17ce988ee085b339c0960df2818428e15ad81abbc2216b118fcc27700ad6c738
+size 227174
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0d33d31afcfeb8b0a31dc4ab180fe1e55a2d63a
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24aafd1030168a62bda314878cb45e5b71fa10b7b1b52af789c34f0d7990b9bb
+size 219898
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png
new file mode 100644
index 0000000000000000000000000000000000000000..c756680f5978d6888cbff9393c13ee2257a9d3ab
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e632294988ca1b860a032a109b88dabe01fd355a16731166f911304e6c6f077
+size 186632
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_lora.png
new file mode 100644
index 0000000000000000000000000000000000000000..268634acbbd05cec6a7520fd5891ed2cc682247c
Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_lora.png differ
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ab0c363518ecff97a596985a4fd9458b0fe3e10
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c000b09dd84f6cdd71d828d251ddb96317da7efda6f180fb5249275eaffd5cde
+size 178303
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png
new file mode 100644
index 0000000000000000000000000000000000000000..d95cf109345620af88462185399e0877049f8771
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85d8d0c6bcee0597c5e1c559f1be45c8170759e48ef34833c5b3bcad4f8458fe
+size 176911
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png
new file mode 100644
index 0000000000000000000000000000000000000000..ea168c31e9cda0a51270ea1237f7ff014dbc9693
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa9cf827f4bb35145bb6bcbfc0f0eca38e157db3895df881a4b0f6a1876f21fe
+size 260526
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png
new file mode 100644
index 0000000000000000000000000000000000000000..8c7c68a7d33bdbd9100a6e3ef0cb0a7e70750386
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea40744f368d168d2127c62d695714e2621866b759787b5a91957030387bf477
+size 183886
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png
new file mode 100644
index 0000000000000000000000000000000000000000..5072c9556f449461a823792b1cdf0e8c2fe1786c
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09b8a6bcadac61b4a27ae93ecadfb9c35433bbffcf4dc549f98c331cf5613850
+size 165716
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png
new file mode 100644
index 0000000000000000000000000000000000000000..2144b4779e22535956ffe16c3492d6f3f288eb80
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92c36eb53fdde01f3f59e39eeb265b97d4d86a4d7b7a64f5c6157a2891c7e4d6
+size 173606
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png
new file mode 100644
index 0000000000000000000000000000000000000000..60b8af4bd413b0c7bbfa937904b3204c0803cd01
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:515261a5f13dcf5753e8fa8e37bd50f19adae673c949ea9f55eb17cd8362b715
+size 173067
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png
new file mode 100644
index 0000000000000000000000000000000000000000..05e623cce9702f9e7a1b2d7b20d52de1039572e1
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a023d1d969bf7a88930905249dd82839069fbfe943161984401f857f79bfb746
+size 182990
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png
new file mode 100644
index 0000000000000000000000000000000000000000..32674032b759041a6df618043b79c21613f8d829
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9645f05d86a8a02d063f550eb79c3bbd153b31d615e579b120f824fcac8fb20
+size 136164
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png
new file mode 100644
index 0000000000000000000000000000000000000000..8c76f8bf3993df88823eea3902c94def755b8166
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:374b51477ed71e52770a6c9b991033808cae04960e3a92b5168e085eed31499a
+size 347499
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png
new file mode 100644
index 0000000000000000000000000000000000000000..abee4a1459e12ef75bff95b0d1032cc3d6c5e744
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:403950780d101e572ae75a05a5a71abc25070078ab3d36b565486b8649e42642
+size 527666
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.csv
new file mode 100644
index 0000000000000000000000000000000000000000..d86c9ef11d5974d3a7548d4bfd6dfc144bb60ecf
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.csv
@@ -0,0 +1,23 @@
+Achievement,Pretrained,action_diversity,delta_action_diversity,advantage_clip,delta_advantage_clip,attention_only,delta_attention_only,baseline_rl,delta_baseline_rl,bc_wins,delta_bc_wins,entropy_bonus,delta_entropy_bonus,ewc,delta_ewc,ffn_only,delta_ffn_only,frozen_backbone,delta_frozen_backbone,gradient_surgery,delta_gradient_surgery,head_only,delta_head_only,kl_penalty,delta_kl_penalty,layer_ablation_top1,delta_layer_ablation_top1,layer_ablation_top2,delta_layer_ablation_top2,layer_ablation_top3,delta_layer_ablation_top3,llrd,delta_llrd,lora,delta_lora,low_t,delta_low_t,mixed_replay,delta_mixed_replay,normalized_adv,delta_normalized_adv,reward_filtering,delta_reward_filtering,reward_model,delta_reward_model,running_stats,delta_running_stats,t_curriculum,delta_t_curriculum,trust_region_kl,delta_trust_region_kl
+Achievements/collect_coal,0.4857,0.5079,0.0222,0.4338,-0.0519,0.0,-0.4857,0.4429,-0.0429,0.4643,-0.0214,0.4255,-0.0602,0.5407,0.055,0.0303,-0.4554,0.0,-0.4857,0.4118,-0.0739,0.0,-0.4857,0.4815,-0.0042,0.1702,-0.3155,0.2806,-0.2051,0.1812,-0.3046,0.4468,-0.0389,0.0,-0.4857,0.4085,-0.0773,0.5391,0.0533,0.084,-0.4017,0.4317,-0.0541,0.4615,-0.0242,0.5074,0.0216,0.4397,-0.046,0.4074,-0.0783
+Achievements/collect_diamond,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0074,0.0074,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
+Achievements/collect_drink,0.3857,0.4286,0.0429,0.4044,0.0187,0.0,-0.3857,0.3857,0.0,0.3714,-0.0143,0.3333,-0.0524,0.437,0.0513,0.2848,-0.1009,0.0,-0.3857,0.4191,0.0334,0.0,-0.3857,0.3259,-0.0598,0.3333,-0.0524,0.4245,0.0387,0.4203,0.0346,0.4043,0.0185,0.0,-0.3857,0.3944,0.0087,0.4141,0.0283,0.1849,-0.2008,0.3453,-0.0404,0.3427,-0.0431,0.3676,-0.0181,0.3972,0.0114,0.4,0.0143
+Achievements/collect_iron,0.3357,0.3095,-0.0262,0.3309,-0.0048,0.0,-0.3357,0.3786,0.0429,0.3571,0.0214,0.2979,-0.0378,0.3556,0.0198,0.0,-0.3357,0.0,-0.3357,0.3088,-0.0269,0.0,-0.3357,0.3704,0.0347,0.0213,-0.3144,0.1151,-0.2206,0.0652,-0.2705,0.2979,-0.0378,0.0,-0.3357,0.2676,-0.0681,0.4062,0.0705,0.0084,-0.3273,0.4029,0.0672,0.3427,0.0069,0.4118,0.0761,0.3546,0.0189,0.2815,-0.0542
+Achievements/collect_sapling,0.9071,0.9127,0.0056,0.9265,0.0193,0.0,-0.9071,0.8643,-0.0429,0.8571,-0.05,0.9362,0.029,0.9037,-0.0034,0.5576,-0.3496,0.0,-0.9071,0.9191,0.012,0.0,-0.9071,0.9333,0.0262,0.7801,-0.127,0.8993,-0.0079,0.8623,-0.0448,0.8936,-0.0135,0.0,-0.9071,0.9296,0.0224,0.9141,0.0069,1.0,0.0929,0.9065,-0.0007,0.9021,-0.005,0.8897,-0.0174,0.9149,0.0078,0.8889,-0.0183
+Achievements/collect_stone,0.9,0.8651,-0.0349,0.8824,-0.0176,0.0,-0.9,0.8929,-0.0071,0.8929,-0.0071,0.8936,-0.0064,0.9037,0.0037,0.1576,-0.7424,0.0,-0.9,0.8603,-0.0397,0.0,-0.9,0.8815,-0.0185,0.5106,-0.3894,0.5899,-0.3101,0.6522,-0.2478,0.8652,-0.0348,0.0,-0.9,0.8873,-0.0127,0.9141,0.0141,0.3025,-0.5975,0.9065,0.0065,0.8951,-0.0049,0.9044,0.0044,0.8865,-0.0135,0.9333,0.0333
+Achievements/collect_wood,1.0,0.9921,-0.0079,0.9926,-0.0074,0.0,-1.0,0.9929,-0.0071,0.9929,-0.0071,0.9858,-0.0142,1.0,0.0,0.9212,-0.0788,0.0,-1.0,1.0,0.0,0.0,-1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.9928,-0.0072,0.9929,-0.0071,0.0,-1.0,0.9859,-0.0141,1.0,0.0,0.958,-0.042,0.9928,-0.0072,1.0,0.0,0.9926,-0.0074,0.9929,-0.0071,1.0,0.0
+Achievements/defeat_skeleton,0.1571,0.254,0.0968,0.2279,0.0708,0.0,-0.1571,0.2143,0.0571,0.2071,0.05,0.1844,0.0273,0.2,0.0429,0.0242,-0.1329,0.0,-0.1571,0.2132,0.0561,0.0,-0.1571,0.2296,0.0725,0.078,-0.0791,0.0791,-0.078,0.1087,-0.0484,0.0993,-0.0579,0.0,-0.1571,0.1972,0.04,0.2266,0.0694,0.0924,-0.0647,0.2446,0.0875,0.2028,0.0457,0.1838,0.0267,0.2553,0.0982,0.2296,0.0725
+Achievements/defeat_zombie,0.6214,0.6508,0.0294,0.5956,-0.0258,0.0,-0.6214,0.7214,0.1,0.5714,-0.05,0.695,0.0736,0.6815,0.0601,0.2606,-0.3608,0.0,-0.6214,0.6912,0.0697,0.0,-0.6214,0.6519,0.0304,0.5887,-0.0328,0.5971,-0.0243,0.6884,0.067,0.6879,0.0665,0.0,-0.6214,0.6761,0.0546,0.6484,0.027,0.6134,-0.008,0.6906,0.0692,0.6643,0.0429,0.6397,0.0183,0.6454,0.024,0.6815,0.0601
+Achievements/eat_cow,0.3357,0.4603,0.1246,0.3603,0.0246,0.0,-0.3357,0.3429,0.0071,0.3786,0.0429,0.383,0.0473,0.3333,-0.0024,0.1152,-0.2206,0.0,-0.3357,0.375,0.0393,0.0,-0.3357,0.3926,0.0569,0.2057,-0.13,0.2446,-0.0911,0.2971,-0.0386,0.2908,-0.0449,0.0,-0.3357,0.3732,0.0375,0.4219,0.0862,0.2437,-0.092,0.3022,-0.0336,0.3427,0.0069,0.375,0.0393,0.305,-0.0307,0.363,0.0272
+Achievements/eat_plant,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
+Achievements/make_iron_pickaxe,0.0214,0.0238,0.0024,0.0221,0.0006,0.0,-0.0214,0.0214,0.0,0.0,-0.0214,0.0284,0.0069,0.037,0.0156,0.0,-0.0214,0.0,-0.0214,0.0074,-0.0141,0.0,-0.0214,0.0222,0.0008,0.0,-0.0214,0.0144,-0.007,0.0072,-0.0142,0.0213,-0.0002,0.0,-0.0214,0.0141,-0.0073,0.0078,-0.0136,0.0,-0.0214,0.0504,0.0289,0.014,-0.0074,0.0368,0.0153,0.0213,-0.0002,0.0,-0.0214
+Achievements/make_iron_sword,0.1214,0.1349,0.0135,0.125,0.0036,0.0,-0.1214,0.1571,0.0357,0.1571,0.0357,0.1206,-0.0009,0.1259,0.0045,0.0,-0.1214,0.0,-0.1214,0.0956,-0.0258,0.0,-0.1214,0.1778,0.0563,0.0,-0.1214,0.0288,-0.0927,0.0,-0.1214,0.0922,-0.0292,0.0,-0.1214,0.0845,-0.0369,0.1875,0.0661,0.0,-0.1214,0.1367,0.0153,0.1538,0.0324,0.1471,0.0256,0.1277,0.0062,0.0815,-0.0399
+Achievements/make_stone_pickaxe,0.65,0.6746,0.0246,0.6029,-0.0471,0.0,-0.65,0.6714,0.0214,0.6929,0.0429,0.6525,0.0025,0.7407,0.0907,0.0242,-0.6258,0.0,-0.65,0.6838,0.0338,0.0,-0.65,0.6815,0.0315,0.1418,-0.5082,0.3525,-0.2975,0.2971,-0.3529,0.6525,0.0025,0.0,-0.65,0.6408,-0.0092,0.6875,0.0375,0.0252,-0.6248,0.7122,0.0622,0.6993,0.0493,0.6838,0.0338,0.6454,-0.0046,0.7111,0.0611
+Achievements/make_stone_sword,0.7286,0.7302,0.0016,0.7353,0.0067,0.0,-0.7286,0.7214,-0.0071,0.75,0.0214,0.766,0.0374,0.7259,-0.0026,0.0545,-0.674,0.0,-0.7286,0.6765,-0.0521,0.0,-0.7286,0.7185,-0.0101,0.2482,-0.4803,0.3669,-0.3617,0.3696,-0.359,0.7447,0.0161,0.0,-0.7286,0.7465,0.0179,0.7656,0.0371,0.0672,-0.6613,0.7986,0.07,0.6923,-0.0363,0.7279,-0.0006,0.7021,-0.0264,0.763,0.0344
+Achievements/make_wood_pickaxe,0.9357,0.9444,0.0087,0.9265,-0.0092,0.0,-0.9357,0.9286,-0.0071,0.9571,0.0214,0.9362,0.0005,0.9556,0.0198,0.3636,-0.5721,0.0,-0.9357,0.9412,0.0055,0.0,-0.9357,0.9481,0.0124,0.695,-0.2407,0.7554,-0.1803,0.7971,-0.1386,0.922,-0.0137,0.0,-0.9357,0.9155,-0.0202,0.9297,-0.006,0.3445,-0.5912,0.9353,-0.0005,0.9301,-0.0056,0.9412,0.0055,0.922,-0.0137,0.9481,0.0124
+Achievements/make_wood_sword,0.6714,0.6587,-0.0127,0.6765,0.005,0.0,-0.6714,0.6786,0.0071,0.7357,0.0643,0.6879,0.0165,0.6667,-0.0048,0.3091,-0.3623,0.0,-0.6714,0.7132,0.0418,0.0,-0.6714,0.6889,0.0175,0.6596,-0.0119,0.7554,0.084,0.7971,0.1257,0.6738,0.0023,0.0,-0.6714,0.7254,0.0539,0.6875,0.0161,0.1176,-0.5538,0.6978,0.0264,0.6154,-0.056,0.75,0.0786,0.6738,0.0023,0.7481,0.0767
+Achievements/place_furnace,0.8,0.7778,-0.0222,0.8309,0.0309,0.0,-0.8,0.8214,0.0214,0.8214,0.0214,0.8156,0.0156,0.8148,0.0148,0.0303,-0.7697,0.0,-0.8,0.7574,-0.0426,0.0,-0.8,0.8,-0.0,0.1986,-0.6014,0.4388,-0.3612,0.4565,-0.3435,0.7872,-0.0128,0.0,-0.8,0.7958,-0.0042,0.7969,-0.0031,0.2017,-0.5983,0.8417,0.0417,0.7902,-0.0098,0.8088,0.0088,0.8298,0.0298,0.8,-0.0
+Achievements/place_plant,0.5929,0.6508,0.0579,0.7059,0.113,0.0,-0.5929,0.5786,-0.0143,0.6286,0.0357,0.695,0.1022,0.6519,0.059,0.0242,-0.5686,0.0,-0.5929,0.6324,0.0395,0.0,-0.5929,0.6593,0.0664,0.0426,-0.5503,0.0791,-0.5137,0.1014,-0.4914,0.6879,0.0951,0.0,-0.5929,0.6901,0.0973,0.6719,0.079,0.9412,0.3483,0.6906,0.0978,0.7063,0.1134,0.625,0.0321,0.6667,0.0738,0.6519,0.059
+Achievements/place_stone,0.6429,0.6429,0.0,0.7279,0.0851,0.0,-0.6429,0.7143,0.0714,0.7071,0.0643,0.695,0.0522,0.7185,0.0757,0.0485,-0.5944,0.0,-0.6429,0.6397,-0.0032,0.0,-0.6429,0.6444,0.0016,0.227,-0.4159,0.2806,-0.3623,0.2681,-0.3747,0.6879,0.0451,0.0,-0.6429,0.7254,0.0825,0.7187,0.0759,0.1261,-0.5168,0.6259,-0.017,0.6643,0.0215,0.6103,-0.0326,0.695,0.0522,0.7333,0.0905
+Achievements/place_table,0.9786,0.9683,-0.0103,0.9853,0.0067,0.0,-0.9786,0.9571,-0.0214,0.9643,-0.0143,0.9433,-0.0353,0.9778,-0.0008,0.4788,-0.4998,0.0,-0.9786,0.9853,0.0067,0.0,-0.9786,0.9704,-0.0082,0.8014,-0.1772,0.8561,-0.1225,0.8623,-0.1163,0.9504,-0.0282,0.0,-0.9786,0.9507,-0.0279,0.9922,0.0136,0.4538,-0.5248,0.9712,-0.0073,0.958,-0.0205,0.9779,-0.0006,0.9645,-0.014,1.0,0.0214
+Achievements/wake_up,0.0929,0.119,0.0262,0.0735,-0.0193,0.0,-0.0929,0.0786,-0.0143,0.1,0.0071,0.0709,-0.0219,0.0963,0.0034,0.0061,-0.0868,0.0,-0.0929,0.1324,0.0395,0.0,-0.0929,0.1481,0.0553,0.2199,0.127,0.1583,0.0654,0.2391,0.1463,0.1064,0.0135,0.0,-0.0929,0.1197,0.0269,0.0547,-0.0382,0.0,-0.0929,0.0863,-0.0065,0.0699,-0.0229,0.0809,-0.012,0.0496,-0.0432,0.1185,0.0257
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.tex
new file mode 100644
index 0000000000000000000000000000000000000000..4e2eee67b80df22d3ee934921d6f8c62ee4f57d7
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.tex
@@ -0,0 +1,33 @@
+\begin{table}[htbp]
+\centering
+\caption{Per-achievement final unlock rates and delta vs pretrained.}
+\label{tab:achievements}
+\begin{tabular}{lrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr}
+\toprule
+\textbf{Achievement} & \textbf{Pretrained} & \textbf{action\_diversity} & \textbf{delta\_action\_diversity} & \textbf{advantage\_clip} & \textbf{delta\_advantage\_clip} & \textbf{attention\_only} & \textbf{delta\_attention\_only} & \textbf{baseline\_rl} & \textbf{delta\_baseline\_rl} & \textbf{bc\_wins} & \textbf{delta\_bc\_wins} & \textbf{entropy\_bonus} & \textbf{delta\_entropy\_bonus} & \textbf{ewc} & \textbf{delta\_ewc} & \textbf{ffn\_only} & \textbf{delta\_ffn\_only} & \textbf{frozen\_backbone} & \textbf{delta\_frozen\_backbone} & \textbf{gradient\_surgery} & \textbf{delta\_gradient\_surgery} & \textbf{head\_only} & \textbf{delta\_head\_only} & \textbf{kl\_penalty} & \textbf{delta\_kl\_penalty} & \textbf{layer\_ablation\_top1} & \textbf{delta\_layer\_ablation\_top1} & \textbf{layer\_ablation\_top2} & \textbf{delta\_layer\_ablation\_top2} & \textbf{layer\_ablation\_top3} & \textbf{delta\_layer\_ablation\_top3} & \textbf{llrd} & \textbf{delta\_llrd} & \textbf{lora} & \textbf{delta\_lora} & \textbf{low\_t} & \textbf{delta\_low\_t} & \textbf{mixed\_replay} & \textbf{delta\_mixed\_replay} & \textbf{normalized\_adv} & \textbf{delta\_normalized\_adv} & \textbf{reward\_filtering} & \textbf{delta\_reward\_filtering} & \textbf{reward\_model} & \textbf{delta\_reward\_model} & \textbf{running\_stats} & \textbf{delta\_running\_stats} & \textbf{t\_curriculum} & \textbf{delta\_t\_curriculum} & \textbf{trust\_region\_kl} & \textbf{delta\_trust\_region\_kl} \\
+\midrule
+Achievements/collect\_coal & 0.4857 & 0.5079 & 0.0222 & 0.4338 & -0.0519 & 0.0000 & -0.4857 & 0.4429 & -0.0429 & 0.4643 & -0.0214 & 0.4255 & -0.0602 & 0.5407 & 0.0550 & 0.0303 & -0.4554 & 0.0000 & -0.4857 & 0.4118 & -0.0739 & 0.0000 & -0.4857 & 0.4815 & -0.0042 & 0.1702 & -0.3155 & 0.2806 & -0.2051 & 0.1812 & -0.3046 & 0.4468 & -0.0389 & 0.0000 & -0.4857 & 0.4085 & -0.0773 & 0.5391 & 0.0533 & 0.0840 & -0.4017 & 0.4317 & -0.0541 & 0.4615 & -0.0242 & 0.5074 & 0.0216 & 0.4397 & -0.0460 & 0.4074 & -0.0783 \\
+Achievements/collect\_diamond & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0074 & 0.0074 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\
+Achievements/collect\_drink & 0.3857 & 0.4286 & 0.0429 & 0.4044 & 0.0187 & 0.0000 & -0.3857 & 0.3857 & 0.0000 & 0.3714 & -0.0143 & 0.3333 & -0.0524 & 0.4370 & 0.0513 & 0.2848 & -0.1009 & 0.0000 & -0.3857 & 0.4191 & 0.0334 & 0.0000 & -0.3857 & 0.3259 & -0.0598 & 0.3333 & -0.0524 & 0.4245 & 0.0387 & 0.4203 & 0.0346 & 0.4043 & 0.0185 & 0.0000 & -0.3857 & 0.3944 & 0.0087 & 0.4141 & 0.0283 & 0.1849 & -0.2008 & 0.3453 & -0.0404 & 0.3427 & -0.0431 & 0.3676 & -0.0181 & 0.3972 & 0.0114 & 0.4000 & 0.0143 \\
+Achievements/collect\_iron & 0.3357 & 0.3095 & -0.0262 & 0.3309 & -0.0048 & 0.0000 & -0.3357 & 0.3786 & 0.0429 & 0.3571 & 0.0214 & 0.2979 & -0.0378 & 0.3556 & 0.0198 & 0.0000 & -0.3357 & 0.0000 & -0.3357 & 0.3088 & -0.0269 & 0.0000 & -0.3357 & 0.3704 & 0.0347 & 0.0213 & -0.3144 & 0.1151 & -0.2206 & 0.0652 & -0.2705 & 0.2979 & -0.0378 & 0.0000 & -0.3357 & 0.2676 & -0.0681 & 0.4062 & 0.0705 & 0.0084 & -0.3273 & 0.4029 & 0.0672 & 0.3427 & 0.0069 & 0.4118 & 0.0761 & 0.3546 & 0.0189 & 0.2815 & -0.0542 \\
+Achievements/collect\_sapling & 0.9071 & 0.9127 & 0.0056 & 0.9265 & 0.0193 & 0.0000 & -0.9071 & 0.8643 & -0.0429 & 0.8571 & -0.0500 & 0.9362 & 0.0290 & 0.9037 & -0.0034 & 0.5576 & -0.3496 & 0.0000 & -0.9071 & 0.9191 & 0.0120 & 0.0000 & -0.9071 & 0.9333 & 0.0262 & 0.7801 & -0.1270 & 0.8993 & -0.0079 & 0.8623 & -0.0448 & 0.8936 & -0.0135 & 0.0000 & -0.9071 & 0.9296 & 0.0224 & 0.9141 & 0.0069 & 1.0000 & 0.0929 & 0.9065 & -0.0007 & 0.9021 & -0.0050 & 0.8897 & -0.0174 & 0.9149 & 0.0078 & 0.8889 & -0.0183 \\
+Achievements/collect\_stone & 0.9000 & 0.8651 & -0.0349 & 0.8824 & -0.0176 & 0.0000 & -0.9000 & 0.8929 & -0.0071 & 0.8929 & -0.0071 & 0.8936 & -0.0064 & 0.9037 & 0.0037 & 0.1576 & -0.7424 & 0.0000 & -0.9000 & 0.8603 & -0.0397 & 0.0000 & -0.9000 & 0.8815 & -0.0185 & 0.5106 & -0.3894 & 0.5899 & -0.3101 & 0.6522 & -0.2478 & 0.8652 & -0.0348 & 0.0000 & -0.9000 & 0.8873 & -0.0127 & 0.9141 & 0.0141 & 0.3025 & -0.5975 & 0.9065 & 0.0065 & 0.8951 & -0.0049 & 0.9044 & 0.0044 & 0.8865 & -0.0135 & 0.9333 & 0.0333 \\
+Achievements/collect\_wood & 1.0000 & 0.9921 & -0.0079 & 0.9926 & -0.0074 & 0.0000 & -1.0000 & 0.9929 & -0.0071 & 0.9929 & -0.0071 & 0.9858 & -0.0142 & 1.0000 & 0.0000 & 0.9212 & -0.0788 & 0.0000 & -1.0000 & 1.0000 & 0.0000 & 0.0000 & -1.0000 & 1.0000 & 0.0000 & 1.0000 & 0.0000 & 1.0000 & 0.0000 & 0.9928 & -0.0072 & 0.9929 & -0.0071 & 0.0000 & -1.0000 & 0.9859 & -0.0141 & 1.0000 & 0.0000 & 0.9580 & -0.0420 & 0.9928 & -0.0072 & 1.0000 & 0.0000 & 0.9926 & -0.0074 & 0.9929 & -0.0071 & 1.0000 & 0.0000 \\
+Achievements/defeat\_skeleton & 0.1571 & 0.2540 & 0.0968 & 0.2279 & 0.0708 & 0.0000 & -0.1571 & 0.2143 & 0.0571 & 0.2071 & 0.0500 & 0.1844 & 0.0273 & 0.2000 & 0.0429 & 0.0242 & -0.1329 & 0.0000 & -0.1571 & 0.2132 & 0.0561 & 0.0000 & -0.1571 & 0.2296 & 0.0725 & 0.0780 & -0.0791 & 0.0791 & -0.0780 & 0.1087 & -0.0484 & 0.0993 & -0.0579 & 0.0000 & -0.1571 & 0.1972 & 0.0400 & 0.2266 & 0.0694 & 0.0924 & -0.0647 & 0.2446 & 0.0875 & 0.2028 & 0.0457 & 0.1838 & 0.0267 & 0.2553 & 0.0982 & 0.2296 & 0.0725 \\
+Achievements/defeat\_zombie & 0.6214 & 0.6508 & 0.0294 & 0.5956 & -0.0258 & 0.0000 & -0.6214 & 0.7214 & 0.1000 & 0.5714 & -0.0500 & 0.6950 & 0.0736 & 0.6815 & 0.0601 & 0.2606 & -0.3608 & 0.0000 & -0.6214 & 0.6912 & 0.0697 & 0.0000 & -0.6214 & 0.6519 & 0.0304 & 0.5887 & -0.0328 & 0.5971 & -0.0243 & 0.6884 & 0.0670 & 0.6879 & 0.0665 & 0.0000 & -0.6214 & 0.6761 & 0.0546 & 0.6484 & 0.0270 & 0.6134 & -0.0080 & 0.6906 & 0.0692 & 0.6643 & 0.0429 & 0.6397 & 0.0183 & 0.6454 & 0.0240 & 0.6815 & 0.0601 \\
+Achievements/eat\_cow & 0.3357 & 0.4603 & 0.1246 & 0.3603 & 0.0246 & 0.0000 & -0.3357 & 0.3429 & 0.0071 & 0.3786 & 0.0429 & 0.3830 & 0.0473 & 0.3333 & -0.0024 & 0.1152 & -0.2206 & 0.0000 & -0.3357 & 0.3750 & 0.0393 & 0.0000 & -0.3357 & 0.3926 & 0.0569 & 0.2057 & -0.1300 & 0.2446 & -0.0911 & 0.2971 & -0.0386 & 0.2908 & -0.0449 & 0.0000 & -0.3357 & 0.3732 & 0.0375 & 0.4219 & 0.0862 & 0.2437 & -0.0920 & 0.3022 & -0.0336 & 0.3427 & 0.0069 & 0.3750 & 0.0393 & 0.3050 & -0.0307 & 0.3630 & 0.0272 \\
+Achievements/eat\_plant & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\
+Achievements/make\_iron\_pickaxe & 0.0214 & 0.0238 & 0.0024 & 0.0221 & 0.0006 & 0.0000 & -0.0214 & 0.0214 & 0.0000 & 0.0000 & -0.0214 & 0.0284 & 0.0069 & 0.0370 & 0.0156 & 0.0000 & -0.0214 & 0.0000 & -0.0214 & 0.0074 & -0.0141 & 0.0000 & -0.0214 & 0.0222 & 0.0008 & 0.0000 & -0.0214 & 0.0144 & -0.0070 & 0.0072 & -0.0142 & 0.0213 & -0.0002 & 0.0000 & -0.0214 & 0.0141 & -0.0073 & 0.0078 & -0.0136 & 0.0000 & -0.0214 & 0.0504 & 0.0289 & 0.0140 & -0.0074 & 0.0368 & 0.0153 & 0.0213 & -0.0002 & 0.0000 & -0.0214 \\
+Achievements/make\_iron\_sword & 0.1214 & 0.1349 & 0.0135 & 0.1250 & 0.0036 & 0.0000 & -0.1214 & 0.1571 & 0.0357 & 0.1571 & 0.0357 & 0.1206 & -0.0009 & 0.1259 & 0.0045 & 0.0000 & -0.1214 & 0.0000 & -0.1214 & 0.0956 & -0.0258 & 0.0000 & -0.1214 & 0.1778 & 0.0563 & 0.0000 & -0.1214 & 0.0288 & -0.0927 & 0.0000 & -0.1214 & 0.0922 & -0.0292 & 0.0000 & -0.1214 & 0.0845 & -0.0369 & 0.1875 & 0.0661 & 0.0000 & -0.1214 & 0.1367 & 0.0153 & 0.1538 & 0.0324 & 0.1471 & 0.0256 & 0.1277 & 0.0062 & 0.0815 & -0.0399 \\
+Achievements/make\_stone\_pickaxe & 0.6500 & 0.6746 & 0.0246 & 0.6029 & -0.0471 & 0.0000 & -0.6500 & 0.6714 & 0.0214 & 0.6929 & 0.0429 & 0.6525 & 0.0025 & 0.7407 & 0.0907 & 0.0242 & -0.6258 & 0.0000 & -0.6500 & 0.6838 & 0.0338 & 0.0000 & -0.6500 & 0.6815 & 0.0315 & 0.1418 & -0.5082 & 0.3525 & -0.2975 & 0.2971 & -0.3529 & 0.6525 & 0.0025 & 0.0000 & -0.6500 & 0.6408 & -0.0092 & 0.6875 & 0.0375 & 0.0252 & -0.6248 & 0.7122 & 0.0622 & 0.6993 & 0.0493 & 0.6838 & 0.0338 & 0.6454 & -0.0046 & 0.7111 & 0.0611 \\
+Achievements/make\_stone\_sword & 0.7286 & 0.7302 & 0.0016 & 0.7353 & 0.0067 & 0.0000 & -0.7286 & 0.7214 & -0.0071 & 0.7500 & 0.0214 & 0.7660 & 0.0374 & 0.7259 & -0.0026 & 0.0545 & -0.6740 & 0.0000 & -0.7286 & 0.6765 & -0.0521 & 0.0000 & -0.7286 & 0.7185 & -0.0101 & 0.2482 & -0.4803 & 0.3669 & -0.3617 & 0.3696 & -0.3590 & 0.7447 & 0.0161 & 0.0000 & -0.7286 & 0.7465 & 0.0179 & 0.7656 & 0.0371 & 0.0672 & -0.6613 & 0.7986 & 0.0700 & 0.6923 & -0.0363 & 0.7279 & -0.0006 & 0.7021 & -0.0264 & 0.7630 & 0.0344 \\
+Achievements/make\_wood\_pickaxe & 0.9357 & 0.9444 & 0.0087 & 0.9265 & -0.0092 & 0.0000 & -0.9357 & 0.9286 & -0.0071 & 0.9571 & 0.0214 & 0.9362 & 0.0005 & 0.9556 & 0.0198 & 0.3636 & -0.5721 & 0.0000 & -0.9357 & 0.9412 & 0.0055 & 0.0000 & -0.9357 & 0.9481 & 0.0124 & 0.6950 & -0.2407 & 0.7554 & -0.1803 & 0.7971 & -0.1386 & 0.9220 & -0.0137 & 0.0000 & -0.9357 & 0.9155 & -0.0202 & 0.9297 & -0.0060 & 0.3445 & -0.5912 & 0.9353 & -0.0005 & 0.9301 & -0.0056 & 0.9412 & 0.0055 & 0.9220 & -0.0137 & 0.9481 & 0.0124 \\
+Achievements/make\_wood\_sword & 0.6714 & 0.6587 & -0.0127 & 0.6765 & 0.0050 & 0.0000 & -0.6714 & 0.6786 & 0.0071 & 0.7357 & 0.0643 & 0.6879 & 0.0165 & 0.6667 & -0.0048 & 0.3091 & -0.3623 & 0.0000 & -0.6714 & 0.7132 & 0.0418 & 0.0000 & -0.6714 & 0.6889 & 0.0175 & 0.6596 & -0.0119 & 0.7554 & 0.0840 & 0.7971 & 0.1257 & 0.6738 & 0.0023 & 0.0000 & -0.6714 & 0.7254 & 0.0539 & 0.6875 & 0.0161 & 0.1176 & -0.5538 & 0.6978 & 0.0264 & 0.6154 & -0.0560 & 0.7500 & 0.0786 & 0.6738 & 0.0023 & 0.7481 & 0.0767 \\
+Achievements/place\_furnace & 0.8000 & 0.7778 & -0.0222 & 0.8309 & 0.0309 & 0.0000 & -0.8000 & 0.8214 & 0.0214 & 0.8214 & 0.0214 & 0.8156 & 0.0156 & 0.8148 & 0.0148 & 0.0303 & -0.7697 & 0.0000 & -0.8000 & 0.7574 & -0.0426 & 0.0000 & -0.8000 & 0.8000 & -0.0000 & 0.1986 & -0.6014 & 0.4388 & -0.3612 & 0.4565 & -0.3435 & 0.7872 & -0.0128 & 0.0000 & -0.8000 & 0.7958 & -0.0042 & 0.7969 & -0.0031 & 0.2017 & -0.5983 & 0.8417 & 0.0417 & 0.7902 & -0.0098 & 0.8088 & 0.0088 & 0.8298 & 0.0298 & 0.8000 & -0.0000 \\
+Achievements/place\_plant & 0.5929 & 0.6508 & 0.0579 & 0.7059 & 0.1130 & 0.0000 & -0.5929 & 0.5786 & -0.0143 & 0.6286 & 0.0357 & 0.6950 & 0.1022 & 0.6519 & 0.0590 & 0.0242 & -0.5686 & 0.0000 & -0.5929 & 0.6324 & 0.0395 & 0.0000 & -0.5929 & 0.6593 & 0.0664 & 0.0426 & -0.5503 & 0.0791 & -0.5137 & 0.1014 & -0.4914 & 0.6879 & 0.0951 & 0.0000 & -0.5929 & 0.6901 & 0.0973 & 0.6719 & 0.0790 & 0.9412 & 0.3483 & 0.6906 & 0.0978 & 0.7063 & 0.1134 & 0.6250 & 0.0321 & 0.6667 & 0.0738 & 0.6519 & 0.0590 \\
+Achievements/place\_stone & 0.6429 & 0.6429 & 0.0000 & 0.7279 & 0.0851 & 0.0000 & -0.6429 & 0.7143 & 0.0714 & 0.7071 & 0.0643 & 0.6950 & 0.0522 & 0.7185 & 0.0757 & 0.0485 & -0.5944 & 0.0000 & -0.6429 & 0.6397 & -0.0032 & 0.0000 & -0.6429 & 0.6444 & 0.0016 & 0.2270 & -0.4159 & 0.2806 & -0.3623 & 0.2681 & -0.3747 & 0.6879 & 0.0451 & 0.0000 & -0.6429 & 0.7254 & 0.0825 & 0.7187 & 0.0759 & 0.1261 & -0.5168 & 0.6259 & -0.0170 & 0.6643 & 0.0215 & 0.6103 & -0.0326 & 0.6950 & 0.0522 & 0.7333 & 0.0905 \\
+Achievements/place\_table & 0.9786 & 0.9683 & -0.0103 & 0.9853 & 0.0067 & 0.0000 & -0.9786 & 0.9571 & -0.0214 & 0.9643 & -0.0143 & 0.9433 & -0.0353 & 0.9778 & -0.0008 & 0.4788 & -0.4998 & 0.0000 & -0.9786 & 0.9853 & 0.0067 & 0.0000 & -0.9786 & 0.9704 & -0.0082 & 0.8014 & -0.1772 & 0.8561 & -0.1225 & 0.8623 & -0.1163 & 0.9504 & -0.0282 & 0.0000 & -0.9786 & 0.9507 & -0.0279 & 0.9922 & 0.0136 & 0.4538 & -0.5248 & 0.9712 & -0.0073 & 0.9580 & -0.0205 & 0.9779 & -0.0006 & 0.9645 & -0.0140 & 1.0000 & 0.0214 \\
+Achievements/wake\_up & 0.0929 & 0.1190 & 0.0262 & 0.0735 & -0.0193 & 0.0000 & -0.0929 & 0.0786 & -0.0143 & 0.1000 & 0.0071 & 0.0709 & -0.0219 & 0.0963 & 0.0034 & 0.0061 & -0.0868 & 0.0000 & -0.0929 & 0.1324 & 0.0395 & 0.0000 & -0.0929 & 0.1481 & 0.0553 & 0.2199 & 0.1270 & 0.1583 & 0.0654 & 0.2391 & 0.1463 & 0.1064 & 0.0135 & 0.0000 & -0.0929 & 0.1197 & 0.0269 & 0.0547 & -0.0382 & 0.0000 & -0.0929 & 0.0863 & -0.0065 & 0.0699 & -0.0229 & 0.0809 & -0.0120 & 0.0496 & -0.0432 & 0.1185 & 0.0257 \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.csv
new file mode 100644
index 0000000000000000000000000000000000000000..dad9791fe403844255f18bdb5b5b7fe38e1bf881
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.csv
@@ -0,0 +1,26 @@
+Method,First_Collapse_Iter,Min_Score,Recovery_Score,Recovered
+baseline_rl,never,9.7303,10.8216,N/A
+kl_penalty,never,9.8106,10.988,N/A
+ewc,never,9.5767,10.7862,N/A
+llrd,never,10.0296,10.7272,N/A
+lora,25,-0.9283,-0.9103,N
+mixed_replay,never,10.0754,10.9151,N/A
+trust_region_kl,never,9.8786,10.6206,N/A
+t_curriculum,never,10.1262,10.9168,N/A
+entropy_bonus,never,9.8287,10.919,N/A
+gradient_surgery,never,9.8682,10.8241,N/A
+advantage_clip,never,9.737,10.7793,N/A
+normalized_adv,250,5.3672,5.0739,N
+bc_wins,never,9.8543,10.8676,N/A
+low_t,never,9.9204,10.6598,N/A
+frozen_backbone,75,-0.9247,0.2878,N
+head_only,75,-0.9247,0.2743,N
+attention_only,75,-0.9247,0.5447,N
+ffn_only,200,3.0321,1.8791,N
+layer_ablation_top1,175,6.2978,4.0468,N
+layer_ablation_top2,225,7.2379,6.3875,N
+layer_ablation_top3,175,7.6156,6.0164,Y
+reward_filtering,never,9.8508,11.0502,N/A
+running_stats,never,10.0158,10.9616,N/A
+action_diversity,never,9.9356,11.0773,N/A
+reward_model,never,9.9416,10.8205,N/A
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.tex
new file mode 100644
index 0000000000000000000000000000000000000000..e5efff8ed7de725f9451c5f3386ff45801fc9907
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.tex
@@ -0,0 +1,36 @@
+\begin{table}[htbp]
+\centering
+\caption{Catastrophic forgetting timeline.}
+\label{tab:forgetting}
+\begin{tabular}{lrrrr}
+\toprule
+\textbf{Method} & \textbf{First\_Collapse\_Iter} & \textbf{Min\_Score} & \textbf{Recovery\_Score} & \textbf{Recovered} \\
+\midrule
+baseline\_rl & never & 9.7303 & 10.8216 & N/A \\
+kl\_penalty & never & 9.8106 & 10.9880 & N/A \\
+ewc & never & 9.5767 & 10.7862 & N/A \\
+llrd & never & 10.0296 & 10.7272 & N/A \\
+lora & 25 & -0.9283 & -0.9103 & N \\
+mixed\_replay & never & 10.0754 & 10.9151 & N/A \\
+trust\_region\_kl & never & 9.8786 & 10.6206 & N/A \\
+t\_curriculum & never & 10.1262 & 10.9168 & N/A \\
+entropy\_bonus & never & 9.8287 & 10.9190 & N/A \\
+gradient\_surgery & never & 9.8682 & 10.8241 & N/A \\
+advantage\_clip & never & 9.7370 & 10.7793 & N/A \\
+normalized\_adv & 250 & 5.3672 & 5.0739 & N \\
+bc\_wins & never & 9.8543 & 10.8676 & N/A \\
+low\_t & never & 9.9204 & 10.6598 & N/A \\
+frozen\_backbone & 75 & -0.9247 & 0.2878 & N \\
+head\_only & 75 & -0.9247 & 0.2743 & N \\
+attention\_only & 75 & -0.9247 & 0.5447 & N \\
+ffn\_only & 200 & 3.0321 & 1.8791 & N \\
+layer\_ablation\_top1 & 175 & 6.2978 & 4.0468 & N \\
+layer\_ablation\_top2 & 225 & 7.2379 & 6.3875 & N \\
+layer\_ablation\_top3 & 175 & 7.6156 & 6.0164 & Y \\
+reward\_filtering & never & 9.8508 & 11.0502 & N/A \\
+running\_stats & never & 10.0158 & 10.9616 & N/A \\
+action\_diversity & never & 9.9356 & 11.0773 & N/A \\
+reward\_model & never & 9.9416 & 10.8205 & N/A \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ad36c2fe857541bcd1eafa2f712b994ced69751a
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.csv
@@ -0,0 +1,26 @@
+Method,Mean_Grad_Align,Final_Grad_Align,Trend,Mean_KL_Drift,Final_KL_Drift
+trust_region_kl,0.0555,0.1171,up,0.059783,0.064425
+kl_penalty,0.0206,0.0969,up,0.14607,0.167528
+mixed_replay,-0.0176,0.0927,up,0.158821,0.173193
+entropy_bonus,0.0139,0.0891,up,0.16366,0.189308
+running_stats,0.023,0.0808,down,0.180406,0.21801
+reward_filtering,0.0208,0.0775,up,0.114046,0.151835
+lora,0.0827,0.077,down,594365.340527,1679591.0
+bc_wins,0.0092,0.0702,down,0.19145,0.219727
+llrd,0.0125,0.068,up,0.137441,0.188783
+advantage_clip,-0.0014,0.0667,up,0.189775,0.219086
+gradient_surgery,0.0153,0.065,up,0.159825,0.209463
+baseline_rl,0.0153,0.0645,up,0.159842,0.2096
+action_diversity,0.0153,0.0644,up,0.159893,0.209855
+t_curriculum,0.0152,0.0495,up,0.197279,0.225474
+reward_model,0.0169,0.0449,down,0.171694,0.222742
+normalized_adv,0.0111,0.035,up,33.724305,63.2407
+layer_ablation_top3,0.0474,0.0302,down,12.019569,16.208744
+low_t,0.0246,0.0278,down,0.13292,0.187926
+attention_only,0.0171,0.0192,up,110.250885,90.191254
+frozen_backbone,0.0152,0.0109,down,125.046066,115.5261
+head_only,0.0151,0.0109,down,125.37802,115.826111
+ffn_only,0.0466,0.0048,down,15.351683,23.891375
+layer_ablation_top1,0.0475,-0.0028,down,11.511555,13.324247
+layer_ablation_top2,0.0461,-0.004,down,11.360045,9.672604
+ewc,0.0251,-0.0246,down,0.153398,0.194954
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.tex
new file mode 100644
index 0000000000000000000000000000000000000000..df7e4d325bf6336985d9ac4e235d24e039c23549
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.tex
@@ -0,0 +1,36 @@
+\begin{table}[htbp]
+\centering
+\caption{Gradient alignment and representation drift analysis.}
+\label{tab:gradient}
+\begin{tabular}{lrrrrr}
+\toprule
+\textbf{Method} & \textbf{Mean\_Grad\_Align} & \textbf{Final\_Grad\_Align} & \textbf{Trend} & \textbf{Mean\_KL\_Drift} & \textbf{Final\_KL\_Drift} \\
+\midrule
+trust\_region\_kl & 0.0555 & 0.1171 & up & 0.0598 & 0.0644 \\
+kl\_penalty & 0.0206 & 0.0969 & up & 0.1461 & 0.1675 \\
+mixed\_replay & -0.0176 & 0.0927 & up & 0.1588 & 0.1732 \\
+entropy\_bonus & 0.0139 & 0.0891 & up & 0.1637 & 0.1893 \\
+running\_stats & 0.0230 & 0.0808 & down & 0.1804 & 0.2180 \\
+reward\_filtering & 0.0208 & 0.0775 & up & 0.1140 & 0.1518 \\
+lora & 0.0827 & 0.0770 & down & 594365.3405 & 1679591.0000 \\
+bc\_wins & 0.0092 & 0.0702 & down & 0.1915 & 0.2197 \\
+llrd & 0.0125 & 0.0680 & up & 0.1374 & 0.1888 \\
+advantage\_clip & -0.0014 & 0.0667 & up & 0.1898 & 0.2191 \\
+gradient\_surgery & 0.0153 & 0.0650 & up & 0.1598 & 0.2095 \\
+baseline\_rl & 0.0153 & 0.0645 & up & 0.1598 & 0.2096 \\
+action\_diversity & 0.0153 & 0.0644 & up & 0.1599 & 0.2099 \\
+t\_curriculum & 0.0152 & 0.0495 & up & 0.1973 & 0.2255 \\
+reward\_model & 0.0169 & 0.0449 & down & 0.1717 & 0.2227 \\
+normalized\_adv & 0.0111 & 0.0350 & up & 33.7243 & 63.2407 \\
+layer\_ablation\_top3 & 0.0474 & 0.0302 & down & 12.0196 & 16.2087 \\
+low\_t & 0.0246 & 0.0278 & down & 0.1329 & 0.1879 \\
+attention\_only & 0.0171 & 0.0192 & up & 110.2509 & 90.1913 \\
+frozen\_backbone & 0.0152 & 0.0109 & down & 125.0461 & 115.5261 \\
+head\_only & 0.0151 & 0.0109 & down & 125.3780 & 115.8261 \\
+ffn\_only & 0.0466 & 0.0048 & down & 15.3517 & 23.8914 \\
+layer\_ablation\_top1 & 0.0475 & -0.0028 & down & 11.5116 & 13.3242 \\
+layer\_ablation\_top2 & 0.0461 & -0.0040 & down & 11.3600 & 9.6726 \\
+ewc & 0.0251 & -0.0246 & down & 0.1534 & 0.1950 \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.csv
new file mode 100644
index 0000000000000000000000000000000000000000..8cdde77ea4294628a48d3d8d73054c32d6e33782
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.csv
@@ -0,0 +1,6 @@
+Group,N,Mean,Best,Worst,StdDev
+Baseline,1,10.8216,10.8216,10.8216,0.0
+A,6,8.8545,10.988,-0.9103,4.3686
+B,7,10.0058,10.919,5.0739,2.0151
+C,7,2.7767,6.3875,0.2743,2.4897
+D,4,10.9774,11.0773,10.8205,0.1002
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.tex
new file mode 100644
index 0000000000000000000000000000000000000000..0f34449eb7426fb30c3a1dc4dad7a40c6e004903
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.tex
@@ -0,0 +1,16 @@
+\begin{table}[htbp]
+\centering
+\caption{Group summary statistics.}
+\label{tab:group_summary}
+\begin{tabular}{lrrrrr}
+\toprule
+\textbf{Group} & \textbf{N} & \textbf{Mean} & \textbf{Best} & \textbf{Worst} & \textbf{StdDev} \\
+\midrule
+Baseline & 1 & 10.8216 & 10.8216 & 10.8216 & 0.0000 \\
+A & 6 & 8.8545 & 10.9880 & -0.9103 & 4.3686 \\
+B & 7 & 10.0058 & 10.9190 & 5.0739 & 2.0151 \\
+C & 7 & 2.7767 & 6.3875 & 0.2743 & 2.4897 \\
+D & 4 & 10.9774 & 11.0773 & 10.8205 & 0.1002 \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.csv
new file mode 100644
index 0000000000000000000000000000000000000000..da483f3622b7d1eeaac2269a8a0cda97175c0403
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.csv
@@ -0,0 +1,26 @@
+Ablation,Group,Hypothesis,Result,Conclusion
+baseline_rl,Baseline,Diagnoses whether the RL signal alone causes collapse,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+kl_penalty,A,If this helps: catastrophic forgetting is the primary cause; soft regularisation...,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+ewc,A,If EWC helps: forgetting pretrained representations is the proximate cause,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+llrd,A,If LLRD helps: deep gradient flow into early layers corrupts representations,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+lora,A,If LoRA works: too many unconstrained degrees of freedom cause collapse,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+mixed_replay,A,If mixed replay helps: online data distribution alone is too corrupted,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+trust_region_kl,A,If hard constraint helps: soft KL is insufficient — a hard boundary is needed,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+t_curriculum,B,If curriculum helps: ordering of learning signals matters,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+entropy_bonus,B,If entropy bonus helps: collapse is mode-collapse; not a gradient problem,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+gradient_surgery,B,If PCGrad helps: gradients are conflicting and resolvable by projection,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+advantage_clip,B,If clipping helps: large advantage magnitudes destabilise training,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+normalized_adv,B,If std normalisation helps: simple mean normalisation is too loose,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+bc_wins,B,If BC on wins helps: the return weighting is the specific cause,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+low_t,B,If low-t helps: high-t (coarse-structure) gradients are biased,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+frozen_backbone,C,If frozen backbone helps: deep gradient flow into backbone causes collapse,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+head_only,C,If head-only works: backbone representations are fine; only decision boundary ne...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+attention_only,C,"If attention-only works: model needs routing updates, not feature updates",COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+ffn_only,C,If FFN-only works: stored knowledge (FFN as memory) needs updating; not attentio...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+layer_ablation_top1,C,Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+layer_ablation_top2,C,Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+layer_ablation_top3,C,Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse
+reward_filtering,D,If filtering helps: noisy/low-return data poisons gradients,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+running_stats,D,If running stats help: batch normalisation is too noisy for small batches,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+action_diversity,D,If diversity filtering helps: degenerate PPO plans corrupt training,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
+reward_model,D,If reward model helps: raw returns are too sparse; learned model smooths signal,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.tex
new file mode 100644
index 0000000000000000000000000000000000000000..006475690d5e58594869b4f5d012224683034199
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.tex
@@ -0,0 +1,36 @@
+\begin{table}[htbp]
+\centering
+\caption{Hypothesis verdict per ablation.}
+\label{tab:hypothesis}
+\begin{tabular}{lrrrr}
+\toprule
+\textbf{Ablation} & \textbf{Group} & \textbf{Hypothesis} & \textbf{Result} & \textbf{Conclusion} \\
+\midrule
+baseline\_rl & Baseline & Diagnoses whether the RL signal alone causes collapse & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+kl\_penalty & A & If this helps: catastrophic forgetting is the primary cause; soft regularisation... & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+ewc & A & If EWC helps: forgetting pretrained representations is the proximate cause & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+llrd & A & If LLRD helps: deep gradient flow into early layers corrupts representations & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+lora & A & If LoRA works: too many unconstrained degrees of freedom cause collapse & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+mixed\_replay & A & If mixed replay helps: online data distribution alone is too corrupted & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+trust\_region\_kl & A & If hard constraint helps: soft KL is insufficient — a hard boundary is needed & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+t\_curriculum & B & If curriculum helps: ordering of learning signals matters & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+entropy\_bonus & B & If entropy bonus helps: collapse is mode-collapse; not a gradient problem & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+gradient\_surgery & B & If PCGrad helps: gradients are conflicting and resolvable by projection & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+advantage\_clip & B & If clipping helps: large advantage magnitudes destabilise training & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+normalized\_adv & B & If std normalisation helps: simple mean normalisation is too loose & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+bc\_wins & B & If BC on wins helps: the return weighting is the specific cause & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+low\_t & B & If low-t helps: high-t (coarse-structure) gradients are biased & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+frozen\_backbone & C & If frozen backbone helps: deep gradient flow into backbone causes collapse & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+head\_only & C & If head-only works: backbone representations are fine; only decision boundary ne... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+attention\_only & C & If attention-only works: model needs routing updates, not feature updates & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+ffn\_only & C & If FFN-only works: stored knowledge (FFN as memory) needs updating; not attentio... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+layer\_ablation\_top1 & C & Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+layer\_ablation\_top2 & C & Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+layer\_ablation\_top3 & C & Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\
+reward\_filtering & D & If filtering helps: noisy/low-return data poisons gradients & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+running\_stats & D & If running stats help: batch normalisation is too noisy for small batches & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+action\_diversity & D & If diversity filtering helps: degenerate PPO plans corrupt training & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+reward\_model & D & If reward model helps: raw returns are too sparse; learned model smooths signal & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.csv
new file mode 100644
index 0000000000000000000000000000000000000000..721939fa2b87eef390cc81f7499fdb61204b85e3
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.csv
@@ -0,0 +1,26 @@
+Method,Group,Final_Score,Delta_vs_Pretrained,Delta_vs_Baseline_RL,Verdict
+action_diversity,D,11.0773,0.5487,0.2557,IMPROVEMENT
+reward_filtering,D,11.0502,0.5216,0.2286,IMPROVEMENT
+kl_penalty,A,10.988,0.4594,0.1664,IMPROVEMENT
+running_stats,D,10.9616,0.433,0.14,IMPROVEMENT
+entropy_bonus,B,10.919,0.3904,0.0974,IMPROVEMENT
+t_curriculum,B,10.9168,0.3882,0.0952,IMPROVEMENT
+mixed_replay,A,10.9151,0.3865,0.0935,IMPROVEMENT
+bc_wins,B,10.8676,0.339,0.046,IMPROVEMENT
+gradient_surgery,B,10.8241,0.2955,0.0025,IMPROVEMENT
+baseline_rl,Baseline,10.8216,0.293,0.0,IMPROVEMENT
+reward_model,D,10.8205,0.2919,-0.0011,IMPROVEMENT
+ewc,A,10.7862,0.2577,-0.0353,IMPROVEMENT
+advantage_clip,B,10.7793,0.2507,-0.0423,IMPROVEMENT
+llrd,A,10.7272,0.1986,-0.0944,IMPROVEMENT
+low_t,B,10.6598,0.1312,-0.1618,IMPROVEMENT
+trust_region_kl,A,10.6206,0.092,-0.201,IMPROVEMENT
+layer_ablation_top2,C,6.3875,-4.1411,-4.4341,COLLAPSE
+layer_ablation_top3,C,6.0164,-4.5122,-4.8052,COLLAPSE
+normalized_adv,B,5.0739,-5.4547,-5.7477,COLLAPSE
+layer_ablation_top1,C,4.0468,-6.4818,-6.7748,COLLAPSE
+ffn_only,C,1.8791,-8.6494,-8.9424,COLLAPSE
+attention_only,C,0.5447,-9.9838,-10.2768,COLLAPSE
+frozen_backbone,C,0.2878,-10.2408,-10.5338,COLLAPSE
+head_only,C,0.2743,-10.2543,-10.5473,COLLAPSE
+lora,A,-0.9103,-11.4389,-11.7319,COLLAPSE
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.tex
new file mode 100644
index 0000000000000000000000000000000000000000..d80b09daed497b209dbb19e541a009300337b4e7
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.tex
@@ -0,0 +1,36 @@
+\begin{table}[htbp]
+\centering
+\caption{Main ablation results.}
+\label{tab:main_results}
+\begin{tabular}{lrrrrr}
+\toprule
+\textbf{Method} & \textbf{Group} & \textbf{Final\_Score} & \textbf{Delta\_vs\_Pretrained} & \textbf{Delta\_vs\_Baseline\_RL} & \textbf{Verdict} \\
+\midrule
+action\_diversity & D & 11.0773 & 0.5487 & 0.2557 & IMPROVEMENT \\
+reward\_filtering & D & 11.0502 & 0.5216 & 0.2286 & IMPROVEMENT \\
+kl\_penalty & A & 10.9880 & 0.4594 & 0.1664 & IMPROVEMENT \\
+running\_stats & D & 10.9616 & 0.4330 & 0.1400 & IMPROVEMENT \\
+entropy\_bonus & B & 10.9190 & 0.3904 & 0.0974 & IMPROVEMENT \\
+t\_curriculum & B & 10.9168 & 0.3882 & 0.0952 & IMPROVEMENT \\
+mixed\_replay & A & 10.9151 & 0.3865 & 0.0935 & IMPROVEMENT \\
+bc\_wins & B & 10.8676 & 0.3390 & 0.0460 & IMPROVEMENT \\
+gradient\_surgery & B & 10.8241 & 0.2955 & 0.0025 & IMPROVEMENT \\
+baseline\_rl & Baseline & 10.8216 & 0.2930 & 0.0000 & IMPROVEMENT \\
+reward\_model & D & 10.8205 & 0.2919 & -0.0011 & IMPROVEMENT \\
+ewc & A & 10.7862 & 0.2577 & -0.0353 & IMPROVEMENT \\
+advantage\_clip & B & 10.7793 & 0.2507 & -0.0423 & IMPROVEMENT \\
+llrd & A & 10.7272 & 0.1986 & -0.0944 & IMPROVEMENT \\
+low\_t & B & 10.6598 & 0.1312 & -0.1618 & IMPROVEMENT \\
+trust\_region\_kl & A & 10.6206 & 0.0920 & -0.2010 & IMPROVEMENT \\
+layer\_ablation\_top2 & C & 6.3875 & -4.1411 & -4.4341 & COLLAPSE \\
+layer\_ablation\_top3 & C & 6.0164 & -4.5122 & -4.8052 & COLLAPSE \\
+normalized\_adv & B & 5.0739 & -5.4547 & -5.7477 & COLLAPSE \\
+layer\_ablation\_top1 & C & 4.0468 & -6.4818 & -6.7748 & COLLAPSE \\
+ffn\_only & C & 1.8791 & -8.6494 & -8.9424 & COLLAPSE \\
+attention\_only & C & 0.5447 & -9.9838 & -10.2768 & COLLAPSE \\
+frozen\_backbone & C & 0.2878 & -10.2408 & -10.5338 & COLLAPSE \\
+head\_only & C & 0.2743 & -10.2543 & -10.5473 & COLLAPSE \\
+lora & A & -0.9103 & -11.4389 & -11.7319 & COLLAPSE \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.csv
new file mode 100644
index 0000000000000000000000000000000000000000..2e2027d5d3ce9b158c5a0f050b710caf5b6a662f
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.csv
@@ -0,0 +1,27 @@
+Method,Achievements/collect_coal,Achievements/collect_diamond,Achievements/collect_drink,Achievements/collect_iron,Achievements/collect_sapling,Achievements/collect_stone,Achievements/collect_wood,Achievements/defeat_skeleton,Achievements/defeat_zombie,Achievements/eat_cow,Achievements/eat_plant,Achievements/make_iron_pickaxe,Achievements/make_iron_sword,Achievements/make_stone_pickaxe,Achievements/make_stone_sword,Achievements/make_wood_pickaxe,Achievements/make_wood_sword,Achievements/place_furnace,Achievements/place_plant,Achievements/place_stone,Achievements/place_table,Achievements/wake_up
+pretrained,0.4857,0.0,0.3857,0.3357,0.9071,0.9,1.0,0.1571,0.6214,0.3357,0.0,0.0214,0.1214,0.65,0.7286,0.9357,0.6714,0.8,0.5929,0.6429,0.9786,0.0929
+action_diversity,0.5079,0.0,0.4286,0.3095,0.9127,0.8651,0.9921,0.254,0.6508,0.4603,0.0,0.0238,0.1349,0.6746,0.7302,0.9444,0.6587,0.7778,0.6508,0.6429,0.9683,0.119
+advantage_clip,0.4338,0.0,0.4044,0.3309,0.9265,0.8824,0.9926,0.2279,0.5956,0.3603,0.0,0.0221,0.125,0.6029,0.7353,0.9265,0.6765,0.8309,0.7059,0.7279,0.9853,0.0735
+attention_only,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
+baseline_rl,0.4429,0.0,0.3857,0.3786,0.8643,0.8929,0.9929,0.2143,0.7214,0.3429,0.0,0.0214,0.1571,0.6714,0.7214,0.9286,0.6786,0.8214,0.5786,0.7143,0.9571,0.0786
+bc_wins,0.4643,0.0,0.3714,0.3571,0.8571,0.8929,0.9929,0.2071,0.5714,0.3786,0.0,0.0,0.1571,0.6929,0.75,0.9571,0.7357,0.8214,0.6286,0.7071,0.9643,0.1
+entropy_bonus,0.4255,0.0,0.3333,0.2979,0.9362,0.8936,0.9858,0.1844,0.695,0.383,0.0,0.0284,0.1206,0.6525,0.766,0.9362,0.6879,0.8156,0.695,0.695,0.9433,0.0709
+ewc,0.5407,0.0074,0.437,0.3556,0.9037,0.9037,1.0,0.2,0.6815,0.3333,0.0,0.037,0.1259,0.7407,0.7259,0.9556,0.6667,0.8148,0.6519,0.7185,0.9778,0.0963
+ffn_only,0.0303,0.0,0.2848,0.0,0.5576,0.1576,0.9212,0.0242,0.2606,0.1152,0.0,0.0,0.0,0.0242,0.0545,0.3636,0.3091,0.0303,0.0242,0.0485,0.4788,0.0061
+frozen_backbone,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
+gradient_surgery,0.4118,0.0,0.4191,0.3088,0.9191,0.8603,1.0,0.2132,0.6912,0.375,0.0,0.0074,0.0956,0.6838,0.6765,0.9412,0.7132,0.7574,0.6324,0.6397,0.9853,0.1324
+head_only,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
+kl_penalty,0.4815,0.0,0.3259,0.3704,0.9333,0.8815,1.0,0.2296,0.6519,0.3926,0.0,0.0222,0.1778,0.6815,0.7185,0.9481,0.6889,0.8,0.6593,0.6444,0.9704,0.1481
+layer_ablation_top1,0.1702,0.0,0.3333,0.0213,0.7801,0.5106,1.0,0.078,0.5887,0.2057,0.0,0.0,0.0,0.1418,0.2482,0.695,0.6596,0.1986,0.0426,0.227,0.8014,0.2199
+layer_ablation_top2,0.2806,0.0,0.4245,0.1151,0.8993,0.5899,1.0,0.0791,0.5971,0.2446,0.0,0.0144,0.0288,0.3525,0.3669,0.7554,0.7554,0.4388,0.0791,0.2806,0.8561,0.1583
+layer_ablation_top3,0.1812,0.0,0.4203,0.0652,0.8623,0.6522,0.9928,0.1087,0.6884,0.2971,0.0,0.0072,0.0,0.2971,0.3696,0.7971,0.7971,0.4565,0.1014,0.2681,0.8623,0.2391
+llrd,0.4468,0.0,0.4043,0.2979,0.8936,0.8652,0.9929,0.0993,0.6879,0.2908,0.0,0.0213,0.0922,0.6525,0.7447,0.922,0.6738,0.7872,0.6879,0.6879,0.9504,0.1064
+lora,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
+low_t,0.4085,0.0,0.3944,0.2676,0.9296,0.8873,0.9859,0.1972,0.6761,0.3732,0.0,0.0141,0.0845,0.6408,0.7465,0.9155,0.7254,0.7958,0.6901,0.7254,0.9507,0.1197
+mixed_replay,0.5391,0.0,0.4141,0.4062,0.9141,0.9141,1.0,0.2266,0.6484,0.4219,0.0,0.0078,0.1875,0.6875,0.7656,0.9297,0.6875,0.7969,0.6719,0.7187,0.9922,0.0547
+normalized_adv,0.084,0.0,0.1849,0.0084,1.0,0.3025,0.958,0.0924,0.6134,0.2437,0.0,0.0,0.0,0.0252,0.0672,0.3445,0.1176,0.2017,0.9412,0.1261,0.4538,0.0
+reward_filtering,0.4317,0.0,0.3453,0.4029,0.9065,0.9065,0.9928,0.2446,0.6906,0.3022,0.0,0.0504,0.1367,0.7122,0.7986,0.9353,0.6978,0.8417,0.6906,0.6259,0.9712,0.0863
+reward_model,0.4615,0.0,0.3427,0.3427,0.9021,0.8951,1.0,0.2028,0.6643,0.3427,0.0,0.014,0.1538,0.6993,0.6923,0.9301,0.6154,0.7902,0.7063,0.6643,0.958,0.0699
+running_stats,0.5074,0.0,0.3676,0.4118,0.8897,0.9044,0.9926,0.1838,0.6397,0.375,0.0,0.0368,0.1471,0.6838,0.7279,0.9412,0.75,0.8088,0.625,0.6103,0.9779,0.0809
+t_curriculum,0.4397,0.0,0.3972,0.3546,0.9149,0.8865,0.9929,0.2553,0.6454,0.305,0.0,0.0213,0.1277,0.6454,0.7021,0.922,0.6738,0.8298,0.6667,0.695,0.9645,0.0496
+trust_region_kl,0.4074,0.0,0.4,0.2815,0.8889,0.9333,1.0,0.2296,0.6815,0.363,0.0,0.0,0.0815,0.7111,0.763,0.9481,0.7481,0.8,0.6519,0.7333,1.0,0.1185
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.tex
new file mode 100644
index 0000000000000000000000000000000000000000..602c028fd2909525d26c154b7a1327dd780d38e2
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.tex
@@ -0,0 +1,37 @@
+\begin{table}[htbp]
+\centering
+\caption{Per-environment (per-achievement) win rates at final eval.}
+\label{tab:per_env}
+\begin{tabular}{lrrrrrrrrrrrrrrrrrrrrrr}
+\toprule
+\textbf{Method} & \textbf{Achievements/collect\_coal} & \textbf{Achievements/collect\_diamond} & \textbf{Achievements/collect\_drink} & \textbf{Achievements/collect\_iron} & \textbf{Achievements/collect\_sapling} & \textbf{Achievements/collect\_stone} & \textbf{Achievements/collect\_wood} & \textbf{Achievements/defeat\_skeleton} & \textbf{Achievements/defeat\_zombie} & \textbf{Achievements/eat\_cow} & \textbf{Achievements/eat\_plant} & \textbf{Achievements/make\_iron\_pickaxe} & \textbf{Achievements/make\_iron\_sword} & \textbf{Achievements/make\_stone\_pickaxe} & \textbf{Achievements/make\_stone\_sword} & \textbf{Achievements/make\_wood\_pickaxe} & \textbf{Achievements/make\_wood\_sword} & \textbf{Achievements/place\_furnace} & \textbf{Achievements/place\_plant} & \textbf{Achievements/place\_stone} & \textbf{Achievements/place\_table} & \textbf{Achievements/wake\_up} \\
+\midrule
+pretrained & 0.4857 & 0.0000 & 0.3857 & 0.3357 & 0.9071 & 0.9000 & 1.0000 & 0.1571 & 0.6214 & 0.3357 & 0.0000 & 0.0214 & 0.1214 & 0.6500 & 0.7286 & 0.9357 & 0.6714 & 0.8000 & 0.5929 & 0.6429 & 0.9786 & 0.0929 \\
+action\_diversity & 0.5079 & 0.0000 & 0.4286 & 0.3095 & 0.9127 & 0.8651 & 0.9921 & 0.2540 & 0.6508 & 0.4603 & 0.0000 & 0.0238 & 0.1349 & 0.6746 & 0.7302 & 0.9444 & 0.6587 & 0.7778 & 0.6508 & 0.6429 & 0.9683 & 0.1190 \\
+advantage\_clip & 0.4338 & 0.0000 & 0.4044 & 0.3309 & 0.9265 & 0.8824 & 0.9926 & 0.2279 & 0.5956 & 0.3603 & 0.0000 & 0.0221 & 0.1250 & 0.6029 & 0.7353 & 0.9265 & 0.6765 & 0.8309 & 0.7059 & 0.7279 & 0.9853 & 0.0735 \\
+attention\_only & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\
+baseline\_rl & 0.4429 & 0.0000 & 0.3857 & 0.3786 & 0.8643 & 0.8929 & 0.9929 & 0.2143 & 0.7214 & 0.3429 & 0.0000 & 0.0214 & 0.1571 & 0.6714 & 0.7214 & 0.9286 & 0.6786 & 0.8214 & 0.5786 & 0.7143 & 0.9571 & 0.0786 \\
+bc\_wins & 0.4643 & 0.0000 & 0.3714 & 0.3571 & 0.8571 & 0.8929 & 0.9929 & 0.2071 & 0.5714 & 0.3786 & 0.0000 & 0.0000 & 0.1571 & 0.6929 & 0.7500 & 0.9571 & 0.7357 & 0.8214 & 0.6286 & 0.7071 & 0.9643 & 0.1000 \\
+entropy\_bonus & 0.4255 & 0.0000 & 0.3333 & 0.2979 & 0.9362 & 0.8936 & 0.9858 & 0.1844 & 0.6950 & 0.3830 & 0.0000 & 0.0284 & 0.1206 & 0.6525 & 0.7660 & 0.9362 & 0.6879 & 0.8156 & 0.6950 & 0.6950 & 0.9433 & 0.0709 \\
+ewc & 0.5407 & 0.0074 & 0.4370 & 0.3556 & 0.9037 & 0.9037 & 1.0000 & 0.2000 & 0.6815 & 0.3333 & 0.0000 & 0.0370 & 0.1259 & 0.7407 & 0.7259 & 0.9556 & 0.6667 & 0.8148 & 0.6519 & 0.7185 & 0.9778 & 0.0963 \\
+ffn\_only & 0.0303 & 0.0000 & 0.2848 & 0.0000 & 0.5576 & 0.1576 & 0.9212 & 0.0242 & 0.2606 & 0.1152 & 0.0000 & 0.0000 & 0.0000 & 0.0242 & 0.0545 & 0.3636 & 0.3091 & 0.0303 & 0.0242 & 0.0485 & 0.4788 & 0.0061 \\
+frozen\_backbone & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\
+gradient\_surgery & 0.4118 & 0.0000 & 0.4191 & 0.3088 & 0.9191 & 0.8603 & 1.0000 & 0.2132 & 0.6912 & 0.3750 & 0.0000 & 0.0074 & 0.0956 & 0.6838 & 0.6765 & 0.9412 & 0.7132 & 0.7574 & 0.6324 & 0.6397 & 0.9853 & 0.1324 \\
+head\_only & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\
+kl\_penalty & 0.4815 & 0.0000 & 0.3259 & 0.3704 & 0.9333 & 0.8815 & 1.0000 & 0.2296 & 0.6519 & 0.3926 & 0.0000 & 0.0222 & 0.1778 & 0.6815 & 0.7185 & 0.9481 & 0.6889 & 0.8000 & 0.6593 & 0.6444 & 0.9704 & 0.1481 \\
+layer\_ablation\_top1 & 0.1702 & 0.0000 & 0.3333 & 0.0213 & 0.7801 & 0.5106 & 1.0000 & 0.0780 & 0.5887 & 0.2057 & 0.0000 & 0.0000 & 0.0000 & 0.1418 & 0.2482 & 0.6950 & 0.6596 & 0.1986 & 0.0426 & 0.2270 & 0.8014 & 0.2199 \\
+layer\_ablation\_top2 & 0.2806 & 0.0000 & 0.4245 & 0.1151 & 0.8993 & 0.5899 & 1.0000 & 0.0791 & 0.5971 & 0.2446 & 0.0000 & 0.0144 & 0.0288 & 0.3525 & 0.3669 & 0.7554 & 0.7554 & 0.4388 & 0.0791 & 0.2806 & 0.8561 & 0.1583 \\
+layer\_ablation\_top3 & 0.1812 & 0.0000 & 0.4203 & 0.0652 & 0.8623 & 0.6522 & 0.9928 & 0.1087 & 0.6884 & 0.2971 & 0.0000 & 0.0072 & 0.0000 & 0.2971 & 0.3696 & 0.7971 & 0.7971 & 0.4565 & 0.1014 & 0.2681 & 0.8623 & 0.2391 \\
+llrd & 0.4468 & 0.0000 & 0.4043 & 0.2979 & 0.8936 & 0.8652 & 0.9929 & 0.0993 & 0.6879 & 0.2908 & 0.0000 & 0.0213 & 0.0922 & 0.6525 & 0.7447 & 0.9220 & 0.6738 & 0.7872 & 0.6879 & 0.6879 & 0.9504 & 0.1064 \\
+lora & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\
+low\_t & 0.4085 & 0.0000 & 0.3944 & 0.2676 & 0.9296 & 0.8873 & 0.9859 & 0.1972 & 0.6761 & 0.3732 & 0.0000 & 0.0141 & 0.0845 & 0.6408 & 0.7465 & 0.9155 & 0.7254 & 0.7958 & 0.6901 & 0.7254 & 0.9507 & 0.1197 \\
+mixed\_replay & 0.5391 & 0.0000 & 0.4141 & 0.4062 & 0.9141 & 0.9141 & 1.0000 & 0.2266 & 0.6484 & 0.4219 & 0.0000 & 0.0078 & 0.1875 & 0.6875 & 0.7656 & 0.9297 & 0.6875 & 0.7969 & 0.6719 & 0.7187 & 0.9922 & 0.0547 \\
+normalized\_adv & 0.0840 & 0.0000 & 0.1849 & 0.0084 & 1.0000 & 0.3025 & 0.9580 & 0.0924 & 0.6134 & 0.2437 & 0.0000 & 0.0000 & 0.0000 & 0.0252 & 0.0672 & 0.3445 & 0.1176 & 0.2017 & 0.9412 & 0.1261 & 0.4538 & 0.0000 \\
+reward\_filtering & 0.4317 & 0.0000 & 0.3453 & 0.4029 & 0.9065 & 0.9065 & 0.9928 & 0.2446 & 0.6906 & 0.3022 & 0.0000 & 0.0504 & 0.1367 & 0.7122 & 0.7986 & 0.9353 & 0.6978 & 0.8417 & 0.6906 & 0.6259 & 0.9712 & 0.0863 \\
+reward\_model & 0.4615 & 0.0000 & 0.3427 & 0.3427 & 0.9021 & 0.8951 & 1.0000 & 0.2028 & 0.6643 & 0.3427 & 0.0000 & 0.0140 & 0.1538 & 0.6993 & 0.6923 & 0.9301 & 0.6154 & 0.7902 & 0.7063 & 0.6643 & 0.9580 & 0.0699 \\
+running\_stats & 0.5074 & 0.0000 & 0.3676 & 0.4118 & 0.8897 & 0.9044 & 0.9926 & 0.1838 & 0.6397 & 0.3750 & 0.0000 & 0.0368 & 0.1471 & 0.6838 & 0.7279 & 0.9412 & 0.7500 & 0.8088 & 0.6250 & 0.6103 & 0.9779 & 0.0809 \\
+t\_curriculum & 0.4397 & 0.0000 & 0.3972 & 0.3546 & 0.9149 & 0.8865 & 0.9929 & 0.2553 & 0.6454 & 0.3050 & 0.0000 & 0.0213 & 0.1277 & 0.6454 & 0.7021 & 0.9220 & 0.6738 & 0.8298 & 0.6667 & 0.6950 & 0.9645 & 0.0496 \\
+trust\_region\_kl & 0.4074 & 0.0000 & 0.4000 & 0.2815 & 0.8889 & 0.9333 & 1.0000 & 0.2296 & 0.6815 & 0.3630 & 0.0000 & 0.0000 & 0.0815 & 0.7111 & 0.7630 & 0.9481 & 0.7481 & 0.8000 & 0.6519 & 0.7333 & 1.0000 & 0.1185 \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.csv
new file mode 100644
index 0000000000000000000000000000000000000000..728d521c522d389f76e48aac0b5d306c9d1c2998
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.csv
@@ -0,0 +1,26 @@
+Method,KL_mean,KL_low_t,KL_mid_t,KL_high_t
+trust_region_kl,0.064425,0.079172,0.064176,0.046738
+reward_filtering,0.151835,0.164997,0.151723,0.143443
+kl_penalty,0.167528,0.182938,0.169007,0.164755
+mixed_replay,0.173193,0.17947,0.169166,0.165786
+low_t,0.187926,0.196443,0.180127,0.193411
+llrd,0.188783,0.194401,0.183651,0.181313
+entropy_bonus,0.189308,0.198052,0.186509,0.196232
+ewc,0.194954,0.214022,0.190065,0.183393
+gradient_surgery,0.209463,0.216535,0.204606,0.200818
+baseline_rl,0.2096,0.216757,0.204743,0.200695
+action_diversity,0.209855,0.217063,0.205008,0.201086
+running_stats,0.21801,0.223517,0.214218,0.215939
+advantage_clip,0.219086,0.226308,0.216805,0.217097
+bc_wins,0.219727,0.227315,0.217366,0.216954
+reward_model,0.222742,0.231557,0.217179,0.216113
+t_curriculum,0.225474,0.223155,0.21883,0.244676
+layer_ablation_top2,9.672604,8.604239,9.558239,10.714105
+layer_ablation_top1,13.324247,13.311382,13.21699,13.340172
+layer_ablation_top3,16.208744,16.479584,15.960442,16.609873
+ffn_only,23.891375,22.415035,23.823528,25.39563
+normalized_adv,63.2407,63.106873,64.224091,61.898434
+attention_only,90.191254,90.413528,90.065102,90.144936
+frozen_backbone,115.5261,115.571564,115.430725,115.766418
+head_only,115.826111,115.743759,115.716293,116.180679
+lora,1679591.0,1673381.25,1672479.625,1676199.75
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.tex
new file mode 100644
index 0000000000000000000000000000000000000000..2415f42618eb3625e9c19f8a925806e9d3961984
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.tex
@@ -0,0 +1,36 @@
+\begin{table}[htbp]
+\centering
+\caption{Representation drift (KL divergence) at final iteration.}
+\label{tab:repr_drift}
+\begin{tabular}{lrrrr}
+\toprule
+\textbf{Method} & \textbf{KL\_mean} & \textbf{KL\_low\_t} & \textbf{KL\_mid\_t} & \textbf{KL\_high\_t} \\
+\midrule
+trust\_region\_kl & 0.0644 & 0.0792 & 0.0642 & 0.0467 \\
+reward\_filtering & 0.1518 & 0.1650 & 0.1517 & 0.1434 \\
+kl\_penalty & 0.1675 & 0.1829 & 0.1690 & 0.1648 \\
+mixed\_replay & 0.1732 & 0.1795 & 0.1692 & 0.1658 \\
+low\_t & 0.1879 & 0.1964 & 0.1801 & 0.1934 \\
+llrd & 0.1888 & 0.1944 & 0.1837 & 0.1813 \\
+entropy\_bonus & 0.1893 & 0.1981 & 0.1865 & 0.1962 \\
+ewc & 0.1950 & 0.2140 & 0.1901 & 0.1834 \\
+gradient\_surgery & 0.2095 & 0.2165 & 0.2046 & 0.2008 \\
+baseline\_rl & 0.2096 & 0.2168 & 0.2047 & 0.2007 \\
+action\_diversity & 0.2099 & 0.2171 & 0.2050 & 0.2011 \\
+running\_stats & 0.2180 & 0.2235 & 0.2142 & 0.2159 \\
+advantage\_clip & 0.2191 & 0.2263 & 0.2168 & 0.2171 \\
+bc\_wins & 0.2197 & 0.2273 & 0.2174 & 0.2170 \\
+reward\_model & 0.2227 & 0.2316 & 0.2172 & 0.2161 \\
+t\_curriculum & 0.2255 & 0.2232 & 0.2188 & 0.2447 \\
+layer\_ablation\_top2 & 9.6726 & 8.6042 & 9.5582 & 10.7141 \\
+layer\_ablation\_top1 & 13.3242 & 13.3114 & 13.2170 & 13.3402 \\
+layer\_ablation\_top3 & 16.2087 & 16.4796 & 15.9604 & 16.6099 \\
+ffn\_only & 23.8914 & 22.4150 & 23.8235 & 25.3956 \\
+normalized\_adv & 63.2407 & 63.1069 & 64.2241 & 61.8984 \\
+attention\_only & 90.1913 & 90.4135 & 90.0651 & 90.1449 \\
+frozen\_backbone & 115.5261 & 115.5716 & 115.4307 & 115.7664 \\
+head\_only & 115.8261 & 115.7438 & 115.7163 & 116.1807 \\
+lora & 1679591.0000 & 1673381.2500 & 1672479.6250 & 1676199.7500 \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.csv
new file mode 100644
index 0000000000000000000000000000000000000000..6d6dcf3dc7f5e371e9ef2d2ca54105b2ac60d7b6
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.csv
@@ -0,0 +1,26 @@
+Method,HighLow_Ratio,LowHigh_Cos_Sim,Dominant_Regime
+baseline_rl,0.08,0.0888,low-t
+kl_penalty,0.092,0.0543,low-t
+ewc,0.08,0.0666,low-t
+llrd,0.076,0.0818,low-t
+lora,0.328,0.9819,low-t
+mixed_replay,0.081,-0.0058,low-t
+trust_region_kl,0.099,0.06,low-t
+t_curriculum,0.098,0.0614,low-t
+entropy_bonus,0.092,0.0534,low-t
+gradient_surgery,0.08,0.0881,low-t
+advantage_clip,0.078,0.0912,low-t
+normalized_adv,0.303,0.6614,low-t
+bc_wins,0.081,0.0876,low-t
+low_t,0.103,0.0652,low-t
+frozen_backbone,0.313,0.9871,low-t
+head_only,0.314,0.9874,low-t
+attention_only,0.307,0.9822,low-t
+ffn_only,0.305,0.9569,low-t
+layer_ablation_top1,0.302,0.9718,low-t
+layer_ablation_top2,0.326,0.9374,low-t
+layer_ablation_top3,0.325,0.8966,low-t
+reward_filtering,0.075,0.0609,low-t
+running_stats,0.083,0.0936,low-t
+action_diversity,0.08,0.0895,low-t
+reward_model,0.085,0.0771,low-t
diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.tex
new file mode 100644
index 0000000000000000000000000000000000000000..77c802804bfc029bae98d328bb5862fc77b0f208
--- /dev/null
+++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.tex
@@ -0,0 +1,36 @@
+\begin{table}[htbp]
+\centering
+\caption{Timestep distribution analysis.}
+\label{tab:t_dist}
+\begin{tabular}{lrrr}
+\toprule
+\textbf{Method} & \textbf{HighLow\_Ratio} & \textbf{LowHigh\_Cos\_Sim} & \textbf{Dominant\_Regime} \\
+\midrule
+baseline\_rl & 0.0800 & 0.0888 & low-t \\
+kl\_penalty & 0.0920 & 0.0543 & low-t \\
+ewc & 0.0800 & 0.0666 & low-t \\
+llrd & 0.0760 & 0.0818 & low-t \\
+lora & 0.3280 & 0.9819 & low-t \\
+mixed\_replay & 0.0810 & -0.0058 & low-t \\
+trust\_region\_kl & 0.0990 & 0.0600 & low-t \\
+t\_curriculum & 0.0980 & 0.0614 & low-t \\
+entropy\_bonus & 0.0920 & 0.0534 & low-t \\
+gradient\_surgery & 0.0800 & 0.0881 & low-t \\
+advantage\_clip & 0.0780 & 0.0912 & low-t \\
+normalized\_adv & 0.3030 & 0.6614 & low-t \\
+bc\_wins & 0.0810 & 0.0876 & low-t \\
+low\_t & 0.1030 & 0.0652 & low-t \\
+frozen\_backbone & 0.3130 & 0.9871 & low-t \\
+head\_only & 0.3140 & 0.9874 & low-t \\
+attention\_only & 0.3070 & 0.9822 & low-t \\
+ffn\_only & 0.3050 & 0.9569 & low-t \\
+layer\_ablation\_top1 & 0.3020 & 0.9718 & low-t \\
+layer\_ablation\_top2 & 0.3260 & 0.9374 & low-t \\
+layer\_ablation\_top3 & 0.3250 & 0.8966 & low-t \\
+reward\_filtering & 0.0750 & 0.0609 & low-t \\
+running\_stats & 0.0830 & 0.0936 & low-t \\
+action\_diversity & 0.0800 & 0.0895 & low-t \\
+reward\_model & 0.0850 & 0.0771 & low-t \\
+\bottomrule
+\end{tabular}
+\end{table}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..caa86d14bc3ca77c7861e17ffd67e016452e5942
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,25 @@
+[project]
+name = "craftax-remdm-planner"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+ "chex>=0.1.91",
+ "craftax>=1.5.0",
+ "distrax>=0.1.7",
+ "flax>=0.12.6",
+ "huggingface-hub>=1.9.1",
+ "jax>=0.9.2",
+ "matplotlib>=3.10.8",
+ "numpy>=2.4.4",
+ "optax>=0.2.8",
+ "orbax>=0.1.9",
+ "orjson>=3.11.8",
+ "polars>=1.39.3",
+ "pyyaml>=6.0.3",
+ "wandb>=0.25.1",
+]
+
+[project.optional-dependencies]
+cuda = ["jax[cuda13]>=0.9.2"]
diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/diffusion/forward.py b/src/diffusion/forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba91b733c9162d2261dc6adef6fd790159c288be
--- /dev/null
+++ b/src/diffusion/forward.py
@@ -0,0 +1,28 @@
+"""Forward process: q(z_t | x_0) by independent per-token masking."""
+
+from __future__ import annotations
+
+import jax
+import jax.numpy as jnp
+
+
+def forward_process(
+ rng: jax.Array,
+ x_0: jnp.ndarray,
+ alpha_t: jnp.ndarray,
+ mask_id: int,
+) -> jnp.ndarray:
+ """Sample z_t ~ q(z_t | x_0). Each token stays with prob alpha_t, else MASK.
+
+ Args:
+ rng: PRNG key.
+ x_0: [B, H] int32, clean actions.
+ alpha_t: [B] or scalar, retention probability.
+ mask_id: MASK token index (= num_actions).
+
+ Returns:
+ z_t: [B, H] int32.
+ """
+ keep = jax.random.uniform(rng, shape=x_0.shape)
+ alpha_t = jnp.reshape(alpha_t, (-1, 1))
+ return jnp.where(keep < alpha_t, x_0, jnp.array(mask_id, dtype=x_0.dtype))
diff --git a/src/diffusion/loss.py b/src/diffusion/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bfad20c465eee300694fbffed5a58e9df43a63a
--- /dev/null
+++ b/src/diffusion/loss.py
@@ -0,0 +1,122 @@
+"""MDLM ELBO loss for masked discrete diffusion training."""
+
+from __future__ import annotations
+from typing import Any, Callable, Optional
+
+import jax
+import jax.numpy as jnp
+
+from .forward import forward_process
+from .schedules import ScheduleFn
+
+ModelApplyFn = Callable[
+ [Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[Any]], jnp.ndarray
+]
+
+_MAX_WEIGHT: float = 1000.0
+_EPS: float = 1e-5
+
+
+def compute_loss(
+ model_apply: ModelApplyFn,
+ params: Any,
+ rng: jax.Array,
+ x_0: jnp.ndarray,
+ obs: jnp.ndarray,
+ valid: jnp.ndarray,
+ num_actions: int,
+ schedule_fn: ScheduleFn,
+ schedule_deriv_fn: ScheduleFn,
+ sigma_t: float = 0.0,
+ label_smoothing: float = 0.0,
+ advantages: Optional[jnp.ndarray] = None,
+ t_min: float | jax.Array = _EPS,
+ t_max: float | jax.Array = 1.0,
+) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
+ """Continuous-time ELBO loss on masked positions only.
+
+ Args:
+ model_apply: fn(params, obs, z_t, t, rng) -> logits [B, H, V].
+ params: Model parameters.
+ rng: PRNG key.
+ x_0: [B, H] int32, ground-truth actions.
+ obs: [B, obs_dim] float32, observations.
+ valid: [B] bool/float, whether each sample is valid.
+ num_actions: Size of real action vocabulary.
+ schedule_fn: alpha(t).
+ schedule_deriv_fn: d(alpha)/dt (analytic).
+ sigma_t: Remasking correction for ELBO weight (0 = standard MDLM).
+ label_smoothing: Smoothing epsilon (0 = exact ELBO targets).
+ advantages: Optional [B] per-sample weights.
+ t_min: Lower bound for uniform t sampling (default: _EPS).
+ t_max: Upper bound for uniform t sampling (default: 1.0).
+
+ Returns:
+ (loss, info_dict).
+ """
+ B = x_0.shape[0]
+ mask_id = num_actions
+ rng, t_rng, mask_rng, drop_rng = jax.random.split(rng, 4)
+
+ # Sample t ~ U(t_min, t_max). Defaults give full ELBO; narrow range for ablations.
+ t = jax.random.uniform(t_rng, (B,), minval=t_min, maxval=t_max)
+ alpha_t = schedule_fn(t)
+
+ # Analytic loss weight: w(t) = (1 - sigma) * (-d(alpha)/dt) / (1 - alpha(t))
+ neg_alpha_dot = -schedule_deriv_fn(t) # positive quantity
+ weight = (1.0 - sigma_t) * neg_alpha_dot / jnp.maximum(1.0 - alpha_t, _EPS)
+ weight = jnp.minimum(weight, _MAX_WEIGHT)
+
+ # Forward noise
+ z_t = forward_process(mask_rng, x_0, alpha_t, mask_id)
+
+ # Model prediction
+ logits = model_apply(params, obs, z_t, t, drop_rng) # [B, H, V]
+
+ # Cross-entropy on valid masked positions
+ is_masked = (z_t == mask_id).astype(jnp.float32) # [B, H]
+ valid_masked = is_masked * valid[:, None].astype(jnp.float32) # [B, H]
+
+ targets = jax.nn.one_hot(x_0, num_actions)
+ if label_smoothing > 0:
+ targets = (1.0 - label_smoothing) * targets + label_smoothing / num_actions
+
+ log_probs = jax.nn.log_softmax(logits, axis=-1)
+ ce = -jnp.sum(targets * log_probs, axis=-1) # [B, H]
+
+ n_masked = jnp.maximum(valid_masked.sum(axis=-1), 1.0) # [B]
+ per_sample = weight * (ce * valid_masked).sum(axis=-1) / n_masked
+
+ if advantages is not None:
+ per_sample = per_sample * jax.lax.stop_gradient(advantages)
+
+ loss = jnp.mean(per_sample)
+
+ # Diagnostics
+ preds = jnp.argmax(logits, axis=-1)
+ correct = (preds == x_0).astype(jnp.float32)
+ acc = jnp.sum(correct * valid_masked) / jnp.maximum(valid_masked.sum(), 1.0)
+
+ t_bins = jnp.array([0.33, 0.66])
+ t_lo = (t < t_bins[0])[:, None]
+ t_mi = ((t >= t_bins[0]) & (t <= t_bins[1]))[:, None]
+ t_hi = (t > t_bins[1])[:, None]
+
+ def _binned_acc(mask):
+ m = valid_masked * mask
+ return jnp.sum(correct * m) / jnp.maximum(m.sum(), 1.0)
+
+ info = {
+ "loss": loss,
+ "unweighted_loss": jnp.mean((ce * valid_masked).sum(axis=-1) / n_masked),
+ "mean_t": jnp.mean(t),
+ "frac_masked": jnp.mean(is_masked),
+ "accuracy": acc,
+ "acc_t_low": _binned_acc(t_lo),
+ "acc_t_mid": _binned_acc(t_mi),
+ "acc_t_high": _binned_acc(t_hi),
+ }
+ if advantages is not None:
+ info["adv_mean"] = jnp.mean(advantages)
+ info["adv_std"] = jnp.std(advantages)
+ return loss, info
diff --git a/src/diffusion/sampling.py b/src/diffusion/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1fb3532041f93bd4595bc16d19f764b568651df
--- /dev/null
+++ b/src/diffusion/sampling.py
@@ -0,0 +1,312 @@
+"""Reverse diffusion sampling with ReMDM remasking (Wang et al.).
+
+Strategies (Section 4.1):
+ rescale: sigma = eta * sigma_max
+ cap: sigma = min(eta, sigma_max)
+ conf: per-token confidence-based remasking
+
+Loop mode (Section 4.2, Algorithm 3):
+ Phase 1: standard MDLM decode, t in [1, t_on]
+ Phase 2: constant alpha(t_on), remasking active
+ Phase 3: standard MDLM decode, t in [t_off, 0]
+"""
+
+from __future__ import annotations
+from typing import Any, Callable, Optional
+
+import jax
+import jax.numpy as jnp
+
+from .schedules import ScheduleFn
+
+ModelApplyFn = Callable[
+ [Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[Any]], jnp.ndarray
+]
+
+
+# ---------------------------------------------------------------------------
+# Remasking sigma computation
+# ---------------------------------------------------------------------------
+
+def _sigma_max(alpha_t: jnp.ndarray, alpha_s: jnp.ndarray) -> jnp.ndarray:
+ """sigma_max = min(1, (1 - alpha_s) / alpha_t). [Eq. 7]"""
+ return jnp.minimum(1.0, (1.0 - alpha_s) / jnp.maximum(alpha_t, 1e-8))
+
+
+def sigma_rescale(alpha_t, alpha_s, eta):
+ return eta * _sigma_max(alpha_t, alpha_s)
+
+
+def sigma_cap(alpha_t, alpha_s, eta):
+ return jnp.minimum(eta, _sigma_max(alpha_t, alpha_s))
+
+
+def sigma_conf(alpha_t, alpha_s, eta, psi, is_unmasked):
+ """Per-token confidence remasking. Safe against all-masked rows."""
+ base = eta * _sigma_max(alpha_t, alpha_s)
+ any_unmasked = jnp.any(is_unmasked, axis=-1, keepdims=True)
+ neg_psi = jnp.where(is_unmasked, -psi, -jnp.inf)
+ safe_neg_psi = jnp.where(any_unmasked, neg_psi, 0.0)
+ eta_conf = jax.nn.softmax(safe_neg_psi, axis=-1)
+ return jnp.where(is_unmasked, eta_conf * base, 0.0)
+
+
+_SIGMA_FNS = {
+ "rescale": lambda a_t, a_s, eta, *_: sigma_rescale(a_t, a_s, eta),
+ "cap": lambda a_t, a_s, eta, *_: sigma_cap(a_t, a_s, eta),
+ "conf": lambda a_t, a_s, eta, psi, unm: sigma_conf(a_t, a_s, eta, psi, unm),
+}
+
+
+# ---------------------------------------------------------------------------
+# Decoding helpers
+# ---------------------------------------------------------------------------
+
+def _nucleus_sample(rng, logits, top_p):
+ """Top-p sampling from [B, H, V] logits -> [B, H] int32."""
+ probs = jax.nn.softmax(logits, axis=-1)
+ idx = jnp.argsort(-probs, axis=-1)
+ sorted_p = jnp.take_along_axis(probs, idx, axis=-1)
+ cum = jnp.cumsum(sorted_p, axis=-1)
+
+ cutoff = cum - sorted_p
+ sorted_p = jnp.where(cutoff >= top_p, 0.0, sorted_p)
+ sorted_p = sorted_p / jnp.maximum(sorted_p.sum(axis=-1, keepdims=True), 1e-12)
+
+ B, H, V = logits.shape
+ flat = sorted_p.reshape(B * H, V)
+ tokens = jax.random.categorical(rng, jnp.log(flat + 1e-12)).reshape(B, H)
+ return jnp.take_along_axis(idx, tokens[..., None], axis=-1).squeeze(-1)
+
+
+def _decode(rng, logits, temperature, top_p):
+ """Sample tokens from logits. Argmax only when temperature <= 0."""
+ if top_p is not None:
+ return _nucleus_sample(rng, logits / jnp.maximum(temperature, 1e-8), top_p)
+ if temperature > 1e-8:
+ B, H, V = logits.shape
+ scaled = logits / temperature
+ return jax.random.categorical(rng, scaled.reshape(-1, V)).reshape(B, H)
+ return jnp.argmax(logits, axis=-1)
+
+
+# ---------------------------------------------------------------------------
+# Reverse sampling
+# ---------------------------------------------------------------------------
+
+def sample_plan(
+ model_apply: ModelApplyFn,
+ params: Any,
+ rng: jax.Array,
+ obs: jnp.ndarray,
+ num_actions: int,
+ plan_horizon: int,
+ num_steps: int,
+ schedule_fn: ScheduleFn,
+ remask_strategy: str = "cap",
+ eta: float = 0.5,
+ use_loop: bool = False,
+ t_on: float = 0.55,
+ t_off: float = 0.05,
+ temperature: float = 1.0,
+ top_p: Optional[float] = None,
+) -> jnp.ndarray:
+ """Generate an action plan via reverse diffusion with ReMDM remasking.
+
+ Returns:
+ actions: [B, H] int32.
+ """
+ B = obs.shape[0]
+ mask_id = num_actions
+ mask_val = jnp.array(mask_id, dtype=jnp.int32)
+
+ if remask_strategy not in _SIGMA_FNS:
+ raise ValueError(f"Unknown strategy {remask_strategy!r}. Options: {list(_SIGMA_FNS)}")
+ get_sigma = _SIGMA_FNS[remask_strategy]
+
+ # Phase allocation for loop mode
+ if use_loop:
+ f1, f3 = 1.0 - t_on, t_off
+ denom = f1 + f3 + (t_on - t_off)
+ n1 = max(int(round(num_steps * f1 / denom)), 1)
+ n3 = max(int(round(num_steps * f3 / denom)), 1)
+ n2 = max(num_steps - n1 - n3, 1)
+ else:
+ n1, n2, n3 = num_steps, 0, 0
+
+ alpha_loop = schedule_fn(jnp.array(t_on))
+
+ z_init = jnp.full((B, plan_horizon), mask_id, dtype=jnp.int32)
+ psi_init = jnp.full((B, plan_horizon), jnp.inf)
+
+ # ------------------------------------------------------------------
+ # Core denoising step (ReMDM Eq. 6)
+ # ------------------------------------------------------------------
+ def _step(carry, _unused, t_val, alpha_t, alpha_s, sigma_on):
+ z, rng, psi = carry
+ rng, s_rng, u_rng, r_rng = jax.random.split(rng, 4)
+
+ t_inp = jnp.full((B,), t_val)
+ logits = model_apply(params, obs, z, t_inp, None)
+ x_hat = _decode(s_rng, logits, temperature, top_p)
+
+ is_masked = z == mask_id
+ is_unmasked = ~is_masked
+
+ sigma = get_sigma(alpha_t, alpha_s, eta, psi, is_unmasked)
+ sigma = jnp.broadcast_to(sigma, z.shape)
+ sigma = jnp.where(sigma_on, sigma, 0.0)
+
+ # Masked -> unmask probability
+ denom = jnp.maximum(1.0 - alpha_t, 1e-8)
+ p_unmask = jnp.clip((alpha_s - (1.0 - sigma) * alpha_t) / denom, 0.0, 1.0)
+
+ do_unmask = is_masked & (jax.random.uniform(u_rng, z.shape) < p_unmask)
+ do_remask = is_unmasked & (jax.random.uniform(r_rng, z.shape) < sigma)
+
+ z_new = jnp.where(do_unmask, x_hat, z)
+ z_new = jnp.where(do_remask, mask_val, z_new)
+
+ # Update confidence history
+ probs = jax.nn.softmax(logits, axis=-1)
+ decode_prob = jnp.take_along_axis(probs, x_hat[..., None], axis=-1).squeeze(-1)
+ psi_new = jnp.where(do_unmask, decode_prob, psi)
+ psi_new = jnp.where(do_remask, jnp.inf, psi_new)
+
+ return (z_new, rng, psi_new), None
+
+ # ------------------------------------------------------------------
+ # Phase functions
+ # ------------------------------------------------------------------
+ def _phase1_step(carry, idx):
+ t = 1.0 - idx * (1.0 - t_on) / n1
+ s = jnp.maximum(1.0 - (idx + 1) * (1.0 - t_on) / n1, t_on)
+ return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), False)
+
+ def _phase2_step(carry, idx):
+ return _step(carry, idx, t_on, alpha_loop, alpha_loop, True)
+
+ def _phase3_step(carry, idx):
+ t = t_off - idx * t_off / n3
+ s = jnp.maximum(t_off - (idx + 1) * t_off / n3, 0.0)
+ return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), False)
+
+ def _simple_step(carry, idx):
+ t = (num_steps - idx) / num_steps
+ s = jnp.maximum((num_steps - idx - 1) / num_steps, 0.0)
+ return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), True)
+
+ # ------------------------------------------------------------------
+ # Run
+ # ------------------------------------------------------------------
+ carry = (z_init, rng, psi_init)
+
+ if use_loop:
+ carry, _ = jax.lax.scan(_phase1_step, carry, jnp.arange(n1))
+ carry, _ = jax.lax.scan(_phase2_step, carry, jnp.arange(n2))
+ if n3 > 0:
+ carry, _ = jax.lax.scan(_phase3_step, carry, jnp.arange(n3))
+ else:
+ carry, _ = jax.lax.scan(_simple_step, carry, jnp.arange(num_steps))
+
+ z_final = carry[0]
+
+ # Final greedy cleanup for any remaining masks
+ final_logits = model_apply(params, obs, z_final, jnp.zeros((B,)), None)
+ fallback = jnp.argmax(final_logits, axis=-1)
+ return jnp.where(z_final == mask_id, fallback, z_final)
+
+
+# ---------------------------------------------------------------------------
+# Inpainting sampler (MPC / historical prefix)
+# ---------------------------------------------------------------------------
+
+def sample_plan_inpainting(
+ apply_fn: ModelApplyFn,
+ params: Any,
+ rng: jax.Array,
+ obs: jnp.ndarray,
+ history: jnp.ndarray,
+ hist_len: jnp.ndarray,
+ num_actions: int,
+ plan_horizon: int,
+ diffusion_steps: int,
+ temperature: float,
+ top_p: Optional[float],
+) -> jnp.ndarray:
+ """Diffusion sampling with a locked historical prefix (inpainting).
+
+ Positions ``0 .. hist_len[b] - 1`` are fixed to the values in ``history``
+ for each batch element ``b``; the remainder are diffused freely.
+
+ Args:
+ apply_fn: Model apply closure (eval mode, no dropout).
+ params: Model parameter pytree.
+ rng: PRNG key.
+ obs: ``[B, obs_dim]`` conditioning observations.
+ history: ``[B, plan_horizon]`` int32 prefix of executed actions.
+ hist_len: ``[B]`` int32 number of valid prefix tokens per element.
+ num_actions: Size of the real action vocabulary (mask token = ``num_actions``).
+ plan_horizon: Total sequence length.
+ diffusion_steps: Number of denoising iterations.
+ temperature: Softmax temperature for token sampling.
+ top_p: Nucleus-sampling threshold; ``None`` disables nucleus filtering.
+
+ Returns:
+ ``[B, plan_horizon]`` int32 completed action plan.
+ """
+ B = obs.shape[0]
+ mask_id = num_actions
+
+ def _step(carry, step):
+ seq, rng = carry
+ rng, model_rng, sample_rng, remask_rng = jax.random.split(rng, 4)
+
+ ratio = step / diffusion_steps
+ t_tensor = jnp.full((B,), 1.0 - ratio)
+ logits = apply_fn(params, obs, seq, t_tensor, model_rng) / jnp.maximum(temperature, 1e-8)
+
+ # Optional nucleus filtering
+ if top_p is not None:
+ probs = jax.nn.softmax(logits, axis=-1)
+ sorted_idx = jnp.argsort(-probs, axis=-1)
+ sorted_p = jnp.take_along_axis(probs, sorted_idx, axis=-1)
+ cutoff = jnp.cumsum(sorted_p, axis=-1) - sorted_p
+ inv_idx = jnp.argsort(sorted_idx, axis=-1)
+ nucleus_mask = jnp.take_along_axis(cutoff >= top_p, inv_idx, axis=-1)
+ logits = jnp.where(nucleus_mask, -jnp.inf, logits)
+
+ preds = jax.random.categorical(sample_rng, logits, axis=-1)
+ conf = jnp.take_along_axis(
+ jax.nn.softmax(logits, axis=-1), preds[..., None], axis=-1,
+ ).squeeze(-1)
+
+ # Keep top-(ratio * H) most confident predictions unmasked
+ num_unmask = jnp.maximum(1, (plan_horizon * ratio).astype(jnp.int32))
+ sorted_conf = jnp.sort(conf, axis=-1)[..., ::-1]
+ thresh = sorted_conf[jnp.arange(B), num_unmask - 1]
+ seq_new = jnp.where(conf < thresh[:, None], mask_id, preds)
+
+ # Light ReMDM-style remasking
+ remask_prob = 0.15 * (1.0 - ratio)
+ do_remask = (
+ (jax.random.uniform(remask_rng, seq_new.shape) < remask_prob)
+ & (seq_new != mask_id)
+ )
+ seq_new = jnp.where(do_remask, mask_id, seq_new)
+
+ # Lock historical prefix
+ pos = jnp.broadcast_to(jnp.arange(plan_horizon)[None, :], (B, plan_horizon))
+ seq_new = jnp.where(pos < hist_len[:, None], history, seq_new)
+
+ return (seq_new, rng), None
+
+ # Initialise: history locked, remainder fully masked
+ init_seq = jnp.full((B, plan_horizon), mask_id, dtype=jnp.int32)
+ pos = jnp.broadcast_to(jnp.arange(plan_horizon)[None, :], (B, plan_horizon))
+ init_seq = jnp.where(pos < hist_len[:, None], history, init_seq)
+
+ (final_seq, _), _ = jax.lax.scan(
+ _step, (init_seq, rng), jnp.arange(1, diffusion_steps + 1),
+ )
+ return final_seq
diff --git a/src/diffusion/schedules.py b/src/diffusion/schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7e815322cf1d55428d0981b959963c106706ef2
--- /dev/null
+++ b/src/diffusion/schedules.py
@@ -0,0 +1,35 @@
+"""Noise schedules for masked discrete diffusion.
+
+alpha(t) is the retention probability: alpha(0)=1 (clean), alpha(1)=0 (fully masked).
+"""
+
+from __future__ import annotations
+from typing import Callable
+
+import jax.numpy as jnp
+
+ScheduleFn = Callable[[jnp.ndarray], jnp.ndarray]
+
+
+def linear_schedule(t: jnp.ndarray) -> jnp.ndarray:
+ """alpha(t) = 1 - t. Default in MDLM / ReMDM."""
+ return 1.0 - t
+
+
+def linear_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
+ return jnp.full_like(t, -1.0)
+
+
+def cosine_schedule(t: jnp.ndarray) -> jnp.ndarray:
+ """alpha(t) = cos(pi * t / 2)."""
+ return jnp.cos(t * jnp.pi / 2.0)
+
+
+def cosine_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
+ return -(jnp.pi / 2.0) * jnp.sin(t * jnp.pi / 2.0)
+
+
+SCHEDULE_MAP: dict[str, tuple[ScheduleFn, ScheduleFn]] = {
+ "linear": (linear_schedule, linear_schedule_deriv),
+ "cosine": (cosine_schedule, cosine_schedule_deriv),
+}
diff --git a/src/envs/__init__.py b/src/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b7835ea3776d9f016570b56fc9723a55b0f9920
--- /dev/null
+++ b/src/envs/__init__.py
@@ -0,0 +1,15 @@
+"""Environment wrappers for the ReMDM planner."""
+
+from src.envs.wrappers import (
+ DiscreteTokenizationWrapper,
+ OfflineTrajectoryWrapper,
+ PlannerWrapper,
+ SequenceHistoryWrapper,
+)
+
+__all__ = [
+ "DiscreteTokenizationWrapper",
+ "OfflineTrajectoryWrapper",
+ "PlannerWrapper",
+ "SequenceHistoryWrapper",
+]
diff --git a/src/envs/wrappers.py b/src/envs/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f61102c268d864a9af60d3a8024cce92dcf83b78
--- /dev/null
+++ b/src/envs/wrappers.py
@@ -0,0 +1,403 @@
+"""Project-specific Gymnax wrappers for the ReMDM planner."""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Any, Callable, Tuple, Union
+
+import chex
+import jax
+import jax.numpy as jnp
+from flax import struct
+
+from Craftax_Baselines.wrappers import GymnaxWrapper
+
+
+# =============================================================================
+# SequenceHistoryWrapper
+# =============================================================================
+
+
+@struct.dataclass
+class SequenceHistoryState:
+ env_state: Any
+ obs_history: chex.Array # [history_len, *obs_shape]
+ act_history: chex.Array # [history_len] int32
+
+
+class SequenceHistoryWrapper(GymnaxWrapper):
+ """Augments env state with a sliding window of past observations and actions.
+
+ After each step the histories satisfy::
+
+ obs_history[-1] = current observation
+ act_history[i] = action taken from obs_history[i] to reach obs_history[i+1]
+
+ The wrapper returns the current observation unchanged; the sequence context is
+ accessed via ``state.obs_history`` and ``state.act_history`` in the training loop.
+
+ Place this as the **innermost** wrapper (before AutoReset / LogWrapper) so that
+ episode boundaries trigger a proper history reset via the auto-reset mechanism.
+
+ Args:
+ env: Single Gymnax environment.
+ history_len: Number of past timesteps to keep (including current).
+ obs_shape: Shape of a single observation, e.g. ``(obs_dim,)``.
+ """
+
+ def __init__(self, env: Any, history_len: int, obs_shape: Tuple[int, ...]) -> None:
+ super().__init__(env)
+ self.history_len = history_len
+ self.obs_shape = obs_shape
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(
+ self, key: chex.PRNGKey, params: Any = None
+ ) -> Tuple[chex.Array, SequenceHistoryState]:
+ obs, env_state = self._env.reset(key, params)
+ obs_history = jnp.tile(
+ obs[None], [self.history_len] + [1] * len(self.obs_shape)
+ )
+ act_history = jnp.zeros(self.history_len, dtype=jnp.int32)
+ state = SequenceHistoryState(
+ env_state=env_state,
+ obs_history=obs_history,
+ act_history=act_history,
+ )
+ return obs, state
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(
+ self,
+ key: chex.PRNGKey,
+ state: SequenceHistoryState,
+ action: Union[int, float],
+ params: Any = None,
+ ) -> Tuple[chex.Array, SequenceHistoryState, chex.Array, chex.Array, Any]:
+ obs, env_state, reward, done, info = self._env.step(
+ key, state.env_state, action, params
+ )
+ act_history = jnp.roll(state.act_history, -1, axis=0).at[-1].set(action)
+ obs_history = jnp.roll(state.obs_history, -1, axis=0).at[-1].set(obs)
+ new_state = SequenceHistoryState(
+ env_state=env_state,
+ obs_history=obs_history,
+ act_history=act_history,
+ )
+ return obs, new_state, reward, done, info
+
+
+# =============================================================================
+# DiscreteTokenizationWrapper
+# =============================================================================
+
+
+class DiscreteTokenizationWrapper(GymnaxWrapper):
+ """Quantizes continuous observations into discrete token indices.
+
+ Each observation element is mapped to one of ``n_bins`` integer tokens using
+ uniform binning between ``obs_min`` and ``obs_max``.
+
+ The returned observation dtype is int32 with values in ``[0, n_bins - 1]``.
+
+ Args:
+ env: Gymnax environment (or wrapper).
+ n_bins: Number of discrete bins per observation element.
+ obs_min: Per-element lower bound, shape matching the observation.
+ obs_max: Per-element upper bound, shape matching the observation.
+ """
+
+ def __init__(
+ self,
+ env: Any,
+ n_bins: int,
+ obs_min: jnp.ndarray,
+ obs_max: jnp.ndarray,
+ ) -> None:
+ super().__init__(env)
+ self.n_bins = n_bins
+ self.obs_min = obs_min
+ self.obs_max = obs_max
+
+ def _tokenize(self, obs: chex.Array) -> chex.Array:
+ obs_clipped = jnp.clip(obs, self.obs_min, self.obs_max)
+ normalized = (obs_clipped - self.obs_min) / (
+ self.obs_max - self.obs_min + 1e-8
+ )
+ tokens = jnp.floor(normalized * self.n_bins).astype(jnp.int32)
+ return jnp.clip(tokens, 0, self.n_bins - 1)
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(
+ self, key: chex.PRNGKey, params: Any = None
+ ) -> Tuple[chex.Array, Any]:
+ obs, state = self._env.reset(key, params)
+ return self._tokenize(obs), state
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(
+ self,
+ key: chex.PRNGKey,
+ state: Any,
+ action: Union[int, float],
+ params: Any = None,
+ ) -> Tuple[chex.Array, Any, chex.Array, chex.Array, Any]:
+ obs, state, reward, done, info = self._env.step(key, state, action, params)
+ return self._tokenize(obs), state, reward, done, info
+
+
+# =============================================================================
+# PlannerWrapper
+# =============================================================================
+
+
+@struct.dataclass
+class PlannerState:
+ env_state: Any
+ current_plan: chex.Array # [num_envs, plan_horizon] int32
+ plan_step: int
+
+
+class PlannerWrapper(GymnaxWrapper):
+ """Manages the plan / replan cycle for a discrete diffusion planner.
+
+ Expected wrapper stack (inner -> outer)::
+
+ env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
+ -> PlannerWrapper
+
+ The ``planner_apply_fn`` must have the signature::
+
+ fn(rng, model_params, obs) -> jnp.ndarray # [num_envs, plan_horizon] int32
+
+ Args:
+ env: Batched Gymnax environment (already handles num_envs).
+ num_envs: Number of parallel environments.
+ plan_horizon: Total number of actions the diffusion model outputs.
+ replan_every: Steps to execute before requesting a new plan (<= plan_horizon).
+ planner_apply_fn: Callable that invokes the diffusion model.
+ """
+
+ def __init__(
+ self,
+ env: Any,
+ num_envs: int,
+ plan_horizon: int,
+ replan_every: int,
+ planner_apply_fn: Callable[..., jnp.ndarray],
+ ) -> None:
+ super().__init__(env)
+ if replan_every > plan_horizon:
+ raise ValueError(
+ f"replan_every ({replan_every}) must be <= plan_horizon ({plan_horizon})"
+ )
+ self.num_envs = num_envs
+ self.plan_horizon = plan_horizon
+ self.replan_every = replan_every
+ self.planner_apply_fn = planner_apply_fn
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(
+ self, key: chex.PRNGKey, params: Any = None
+ ) -> Tuple[chex.Array, PlannerState]:
+ obs, env_state = self._env.reset(key, params)
+ current_plan = jnp.zeros(
+ (self.num_envs, self.plan_horizon), dtype=jnp.int32
+ )
+ state = PlannerState(
+ env_state=env_state,
+ current_plan=current_plan,
+ plan_step=0,
+ )
+ return obs, state
+
+ @partial(jax.jit, static_argnums=(0,))
+ def step(
+ self,
+ key: chex.PRNGKey,
+ state: PlannerState,
+ last_obs: chex.Array,
+ model_params: Any,
+ env_params: Any = None,
+ ) -> Tuple[chex.Array, PlannerState, chex.Array, chex.Array, chex.Array, Any]:
+ """Step the environment using the diffusion plan.
+
+ Args:
+ key: PRNG key.
+ state: Current PlannerState.
+ last_obs: Most recent batched observation [num_envs, *obs_shape].
+ model_params: Parameters passed to planner_apply_fn.
+ env_params: Optional Gymnax environment params.
+
+ Returns:
+ (obs, state, action, reward, done, info)
+ """
+ key, plan_key, step_key = jax.random.split(key, 3)
+
+ current_plan = jax.lax.cond(
+ state.plan_step == 0,
+ lambda operand: self.planner_apply_fn(*operand),
+ lambda operand: state.current_plan,
+ (plan_key, model_params, last_obs),
+ )
+
+ action = current_plan[:, state.plan_step]
+
+ obs, env_state, reward, done, info = self._env.step(
+ step_key, state.env_state, action, env_params
+ )
+
+ new_plan_step = (state.plan_step + 1) % self.replan_every
+ new_state = PlannerState(
+ env_state=env_state,
+ current_plan=current_plan,
+ plan_step=new_plan_step,
+ )
+ return obs, new_state, action, reward, done, info
+
+
+# =============================================================================
+# OfflineTrajectoryWrapper
+# =============================================================================
+
+
+@struct.dataclass
+class TrajectoryBufferState:
+ env_state: Any
+ last_obs: Any # [*obs_shape]
+ buf_obs: Any # [max_size, *obs_shape]
+ buf_act: Any # [max_size] int32
+ buf_reward: Any # [max_size] float32
+ buf_done: Any # [max_size] bool
+ buf_next_obs: Any # [max_size, *obs_shape]
+ write_idx: Any # int32, wraps at max_size
+ num_valid: Any # int32, capped at max_size
+
+
+class OfflineTrajectoryWrapper(GymnaxWrapper):
+ """Accumulates transitions into a fixed-size circular replay buffer.
+
+ The buffer overwrites the oldest entries once full. Use ``sample_sequences``
+ to draw contiguous subsequences for training a sequence model like ReMDM.
+
+ Designed for a single environment; compose with ``BatchEnvWrapper`` *outside*
+ this wrapper to collect from multiple envs simultaneously.
+
+ Args:
+ env: Single Gymnax environment (or wrapper).
+ max_size: Maximum number of transitions to store.
+ obs_shape: Shape of a single observation, e.g. ``(obs_dim,)``.
+ """
+
+ def __init__(
+ self, env: Any, max_size: int, obs_shape: Tuple[int, ...]
+ ) -> None:
+ super().__init__(env)
+ self.max_size = max_size
+ self.obs_shape = obs_shape
+
+ def _empty_buffer(
+ self, env_state: Any, first_obs: chex.Array
+ ) -> TrajectoryBufferState:
+ return TrajectoryBufferState(
+ env_state=env_state,
+ last_obs=first_obs,
+ buf_obs=jnp.zeros(
+ (self.max_size, *self.obs_shape), dtype=jnp.float32
+ ),
+ buf_act=jnp.zeros(self.max_size, dtype=jnp.int32),
+ buf_reward=jnp.zeros(self.max_size, dtype=jnp.float32),
+ buf_done=jnp.zeros(self.max_size, dtype=jnp.bool_),
+ buf_next_obs=jnp.zeros(
+ (self.max_size, *self.obs_shape), dtype=jnp.float32
+ ),
+ write_idx=jnp.int32(0),
+ num_valid=jnp.int32(0),
+ )
+
+ @partial(jax.jit, static_argnums=(0, 2))
+ def reset(
+ self, key: chex.PRNGKey, params: Any = None
+ ) -> Tuple[chex.Array, TrajectoryBufferState]:
+ obs, env_state = self._env.reset(key, params)
+ state = self._empty_buffer(env_state, obs)
+ return obs, state
+
+ @partial(jax.jit, static_argnums=(0, 4))
+ def step(
+ self,
+ key: chex.PRNGKey,
+ state: TrajectoryBufferState,
+ action: Union[int, float],
+ params: Any = None,
+ ) -> Tuple[chex.Array, TrajectoryBufferState, chex.Array, chex.Array, Any]:
+ obs, env_state, reward, done, info = self._env.step(
+ key, state.env_state, action, params
+ )
+
+ idx = state.write_idx % self.max_size
+ buf_obs = state.buf_obs.at[idx].set(state.last_obs)
+ buf_act = state.buf_act.at[idx].set(action)
+ buf_reward = state.buf_reward.at[idx].set(reward)
+ buf_done = state.buf_done.at[idx].set(done)
+ buf_next_obs = state.buf_next_obs.at[idx].set(obs)
+
+ # Wrap write_idx at max_size to prevent unbounded growth / int32 overflow
+ new_write_idx = (state.write_idx + 1) % self.max_size
+ is_full = state.num_valid >= self.max_size
+ new_num_valid = jnp.where(is_full, self.max_size, state.num_valid + 1)
+
+ new_state = TrajectoryBufferState(
+ env_state=env_state,
+ last_obs=obs,
+ buf_obs=buf_obs,
+ buf_act=buf_act,
+ buf_reward=buf_reward,
+ buf_done=buf_done,
+ buf_next_obs=buf_next_obs,
+ write_idx=new_write_idx,
+ num_valid=new_num_valid,
+ )
+ return obs, new_state, reward, done, info
+
+ @partial(jax.jit, static_argnums=(0, 3, 4))
+ def sample_sequences(
+ self,
+ rng: chex.PRNGKey,
+ state: TrajectoryBufferState,
+ n_samples: int,
+ seq_len: int,
+ ) -> Tuple[
+ chex.Array, chex.Array, chex.Array, chex.Array, chex.Array
+ ]:
+ """Sample ``n_samples`` contiguous subsequences of length ``seq_len``.
+
+ Precondition: ``state.num_valid >= seq_len``.
+
+ Returns:
+ obs [n_samples, seq_len, *obs_shape]
+ act [n_samples, seq_len]
+ reward [n_samples, seq_len]
+ done [n_samples, seq_len]
+ next_obs [n_samples, seq_len, *obs_shape]
+ """
+ max_start = jnp.maximum(state.num_valid - seq_len, 1)
+ start_indices = jax.random.randint(
+ rng, shape=(n_samples,), minval=0, maxval=max_start
+ )
+
+ def gather_seq(
+ start_idx: jnp.ndarray,
+ ) -> Tuple[
+ chex.Array, chex.Array, chex.Array, chex.Array, chex.Array
+ ]:
+ indices = (start_idx + jnp.arange(seq_len)) % self.max_size
+ return (
+ state.buf_obs[indices],
+ state.buf_act[indices],
+ state.buf_reward[indices],
+ state.buf_done[indices],
+ state.buf_next_obs[indices],
+ )
+
+ return jax.vmap(gather_seq)(start_indices)
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/denoiser.py b/src/models/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..972e122f5cd95d5317d3fb1c8062ed167a9a7221
--- /dev/null
+++ b/src/models/denoiser.py
@@ -0,0 +1,125 @@
+"""Denoising Transformer for masked discrete diffusion planning.
+
+Architecture: obs MLP encoder + sinusoidal time embedding + bidirectional
+transformer. Two prefix tokens (obs, time) precede the action sequence.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import jax.numpy as jnp
+import flax.linen as nn
+from flax.linen.initializers import constant, orthogonal
+
+_INIT = orthogonal(np.sqrt(2))
+_INIT_SMALL = orthogonal(0.01)
+_BIAS = constant(0.0)
+
+
+class SinusoidalPosEmbed(nn.Module):
+ """Sinusoidal embedding for continuous timesteps or integer positions."""
+
+ dim: int
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+ half = self.dim // 2
+ freqs = jnp.exp(-jnp.log(10_000.0) * jnp.arange(half) / half)
+ angles = x[..., None] * freqs
+ emb = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
+ if self.dim % 2 == 1:
+ emb = jnp.concatenate([emb, jnp.zeros_like(emb[..., :1])], axis=-1)
+ return emb
+
+
+class TransformerBlock(nn.Module):
+ """Pre-norm transformer: LN -> MHA -> res -> LN -> FFN -> res."""
+
+ d_model: int
+ n_heads: int
+ d_ff: int
+ dropout_rate: float = 0.1
+ deterministic: bool = True
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+ h = nn.LayerNorm()(x)
+ h = nn.MultiHeadDotProductAttention(
+ num_heads=self.n_heads, kernel_init=_INIT, deterministic=self.deterministic,
+ )(h, h)
+ h = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(h)
+ x = x + h
+
+ h = nn.LayerNorm()(x)
+ h = nn.Dense(self.d_ff, kernel_init=_INIT, bias_init=_BIAS)(h)
+ h = nn.gelu(h)
+ h = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(h)
+ h = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(h)
+ return x + h
+
+
+class DenoisingTransformer(nn.Module):
+ """Denoising transformer for masked discrete diffusion planning.
+
+ Input: (obs [B, D], noisy_actions [B, H], timestep [B])
+ Output: logits [B, H, num_actions] (no MASK logit).
+ """
+
+ num_actions: int
+ plan_horizon: int
+ d_model: int = 256
+ n_heads: int = 4
+ n_layers: int = 4
+ d_ff: int = 512
+ obs_encoder_layers: int = 2
+ obs_encoder_width: int = 512
+ dropout_rate: float = 0.1
+
+ @nn.compact
+ def __call__(
+ self,
+ obs: jnp.ndarray,
+ noisy_actions: jnp.ndarray,
+ timestep: jnp.ndarray,
+ deterministic: bool = True,
+ ) -> jnp.ndarray:
+ B = obs.shape[0]
+ vocab = self.num_actions + 1 # +1 for MASK token
+
+ # Observation encoder
+ h = nn.Dense(self.obs_encoder_width, kernel_init=_INIT, bias_init=_BIAS)(obs)
+ h = nn.LayerNorm()(h)
+ h = nn.relu(h)
+ for _ in range(self.obs_encoder_layers - 1):
+ h = nn.Dense(self.obs_encoder_width, kernel_init=_INIT, bias_init=_BIAS)(h)
+ h = nn.relu(h)
+ obs_tok = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(h)[:, None, :]
+
+ # Time embedding
+ t = timestep.reshape(B)
+ t_emb = SinusoidalPosEmbed(self.d_model)(t)
+ t_emb = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(t_emb)
+ t_emb = nn.gelu(t_emb)
+ t_tok = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(t_emb)[:, None, :]
+
+ # Action token embedding
+ act_emb = nn.Embed(num_embeddings=vocab, features=self.d_model)(noisy_actions)
+
+ # Assemble sequence: [obs, time, actions]
+ seq = jnp.concatenate([obs_tok, t_tok, act_emb], axis=1)
+ seq_len = 2 + self.plan_horizon
+ pos_emb = SinusoidalPosEmbed(self.d_model)(jnp.arange(seq_len))
+ seq = seq + pos_emb[None, :, :]
+
+ # Transformer
+ for _ in range(self.n_layers):
+ seq = TransformerBlock(
+ d_model=self.d_model, n_heads=self.n_heads, d_ff=self.d_ff,
+ dropout_rate=self.dropout_rate, deterministic=deterministic,
+ )(seq)
+ seq = nn.LayerNorm()(seq)
+
+ # Output logits over real actions (skip 2 prefix tokens)
+ return nn.Dense(self.num_actions, kernel_init=_INIT_SMALL, bias_init=_BIAS)(
+ seq[:, 2:, :]
+ )
diff --git a/src/planners/__init__.py b/src/planners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/planners/collect.py b/src/planners/collect.py
new file mode 100644
index 0000000000000000000000000000000000000000..7996d982285bca2214ef0d9a3096ff245351aa7d
--- /dev/null
+++ b/src/planners/collect.py
@@ -0,0 +1,75 @@
+"""Collect offline trajectories from a trained PPO agent."""
+
+from __future__ import annotations
+
+import pathlib
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from .ppo import load_ppo_agent
+from .env import make_env
+
+
+def collect_offline_data(config: dict[str, Any]) -> None:
+ """Roll out a PPO agent and save (obs, actions, rewards, dones) to disk.
+
+ Args:
+ config: Upper-cased hyperparameter dict. Must contain
+ ``PPO_CHECKPOINT_PATH``, ``OFFLINE_DATA_PATH``,
+ ``COLLECT_NUM_STEPS``, and ``COLLECT_NUM_ENVS``.
+ """
+ assert config.get("PPO_CHECKPOINT_PATH"), (
+ "--ppo_checkpoint_path is required for --mode collect."
+ )
+
+ num_envs: int = config["COLLECT_NUM_ENVS"]
+ num_iters: int = config["COLLECT_NUM_STEPS"] // num_envs
+
+ env_w, env_params = make_env(config, num_envs)
+ num_actions = env_w.action_space(env_params).n
+ obs_dim = env_w.observation_space(env_params).shape[0]
+
+ ppo = load_ppo_agent(
+ config["PPO_CHECKPOINT_PATH"], num_actions, obs_dim,
+ config.get("LAYER_SIZE", 512),
+ model_type=config.get("PPO_MODEL_TYPE", "ppo_rnn"),
+ config=config, num_envs=num_envs,
+ )
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rng, env_rng, collect_rng = jax.random.split(rng, 3)
+ obs, env_state = env_w.reset(env_rng, env_params)
+ done = jnp.zeros(num_envs, dtype=bool)
+ hidden = ppo.init_hidden(num_envs)
+
+ def _step(carry, _):
+ rng, es, obs, done, hs = carry
+ rng, act_rng, step_rng = jax.random.split(rng, 3)
+ action, new_hs = ppo.act(obs, done, hs, act_rng,
+ temperature=config.get("COLLECT_TEMPERATURE", 1.0))
+ obs_next, es, reward, done_next, _ = env_w.step(step_rng, es, action, env_params)
+ return (rng, es, obs_next, done_next, new_hs), (obs, action, reward, done)
+
+ rollout_fn = jax.jit(lambda c: jax.lax.scan(_step, c, None, length=num_iters))
+ _, (obs_arr, act_arr, rew_arr, done_arr) = rollout_fn(
+ (collect_rng, env_state, obs, done, hidden),
+ )
+
+ # [steps, envs, ...] -> [envs, steps, ...]
+ obs_np = np.array(obs_arr).transpose(1, 0, 2)
+ act_np = np.array(act_arr).transpose(1, 0)
+ rew_np = np.array(rew_arr).transpose(1, 0)
+ done_np = np.array(done_arr).transpose(1, 0)
+
+ out_path = config["OFFLINE_DATA_PATH"]
+ pathlib.Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+ np.savez(out_path, obs=obs_np, actions=act_np, rewards=rew_np, dones=done_np)
+ total = obs_np.shape[0] * obs_np.shape[1]
+ print(f"Saved {obs_np.shape[0]}×{obs_np.shape[1]} transitions ({total:,}) to '{out_path}'")
+
+
+def run_collect(config: dict[str, Any]) -> None:
+ collect_offline_data(config)
diff --git a/src/planners/common.py b/src/planners/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd52d45925cc654d53e39287e351f7177da6eac6
--- /dev/null
+++ b/src/planners/common.py
@@ -0,0 +1,458 @@
+"""Shared gradient-step factory, validation rollout, and action diagnostics.
+
+Both :mod:`src.planners.offline` and :mod:`src.planners.online` use identical
+gradient update and validation logic. Centralising it here eliminates
+duplication.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
+import jax
+import jax.numpy as jnp
+import optax
+
+from src.diffusion.loss import compute_loss
+from src.diffusion.sampling import sample_plan
+from src.diffusion.schedules import ScheduleFn
+
+
+def resolve_num_updates(config: dict[str, Any], mode: str) -> None:
+ """Resolve ``NUM_UPDATES`` from env-frame-denominated config keys.
+
+ Mutates ``config`` in place. After this call the runners can read
+ ``NUM_UPDATES`` (and ``OFFLINE_TOTAL_TIMESTEPS`` /
+ ``ONLINE_TOTAL_TIMESTEPS`` depending on mode) without worrying about
+ whether the user specified the env-frame or update-count form.
+
+ Resolution priority:
+
+ =========== ============================================================
+ Mode Priority (highest first)
+ =========== ============================================================
+ ``offline`` ``OFFLINE_TOTAL_TIMESTEPS`` > ``OFFLINE_NUM_UPDATES``
+ ``online`` ``ONLINE_TOTAL_TIMESTEPS`` > ``ONLINE_NUM_UPDATES``
+ =========== ============================================================
+
+ Env-frame keys are preferred because they are invariant under
+ ``num_envs`` changes — the same value yields the same total environment
+ experience regardless of hardware sizing, which makes cross-hardware
+ fairness studies (e.g. UCL 4096-env vs QMUL 96-env) trivially fair
+ without manual scaling.
+
+ The function is idempotent: calling it twice with the same config has
+ the same effect as calling it once.
+
+ Args:
+ config: Upper-cased config dict. Must contain ``NUM_STEPS`` and
+ ``NUM_ENVS``.
+ mode: Either ``"offline"`` or ``"online"``.
+
+ Raises:
+ ValueError: If neither the env-frame nor the update-count form is
+ set for the given mode, or if ``mode`` is unknown.
+ """
+ frames_per_update = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"])
+
+ if mode == "offline":
+ ts_key, nu_key = "OFFLINE_TOTAL_TIMESTEPS", "OFFLINE_NUM_UPDATES"
+ elif mode == "online":
+ ts_key, nu_key = "ONLINE_TOTAL_TIMESTEPS", "ONLINE_NUM_UPDATES"
+ else:
+ raise ValueError(
+ f"Unknown mode: {mode!r}; expected 'offline' or 'online'."
+ )
+
+ ts = config.get(ts_key)
+ nu = config.get(nu_key)
+ # float() first to accept YAML scientific notation parsed as string
+ # (PyYAML 1.1 only auto-coerces "3.0e+8", not "3e8" or "3.0e8").
+ if ts is not None:
+ num_updates = max(1, int(float(ts)) // frames_per_update)
+ elif nu:
+ num_updates = int(float(nu))
+ else:
+ raise ValueError(
+ f"{mode.capitalize()} mode requires either "
+ f"{ts_key.lower()!r} (env frames, preferred) or "
+ f"{nu_key.lower()!r} to be set."
+ )
+
+ config["NUM_UPDATES"] = num_updates
+ # Re-snap so downstream consumers (run names, SPS, checkpoint IDs)
+ # see the exact integer multiple actually trained.
+ config[ts_key] = num_updates * frames_per_update
+
+
+def resolve_scaled_hyperparams(config: dict[str, Any], mode: str) -> None:
+ """Resolve env-frame-denominated hyperparameters into update-step form.
+
+ Mutates ``config`` in place. PRIMARY (env-frame) keys override LEGACY
+ (update-step) keys when set, mirroring the
+ :func:`resolve_num_updates` pattern. When the PRIMARY key is ``None``
+ the LEGACY value passes through unchanged, preserving full
+ backward compatibility with configs that predate this resolver.
+
+ Resolution table
+ ================
+
+ +-----------------------------+------------------------+----------+
+ | PRIMARY (env-frame) | LEGACY (update-step) | Mode |
+ +=============================+========================+==========+
+ | ``LR_WARMUP_FRAMES`` | ``LR_WARMUP_STEPS`` | both |
+ +-----------------------------+------------------------+----------+
+ | ``VAL_INTERVAL_FRAMES`` | ``VAL_INTERVAL`` | both |
+ +-----------------------------+------------------------+----------+
+ | ``DAGGER_BETA_FINAL`` | ``DAGGER_BETA_DECAY`` | online |
+ +-----------------------------+------------------------+----------+
+ | ``DAGGER_BUFFER_CYCLES`` | ``DAGGER_BUFFER_MAX`` | online |
+ +-----------------------------+------------------------+----------+
+
+ Why env-frame units
+ -------------------
+ Env-frame values are invariant under ``num_envs`` changes, so the
+ same config trains the same effective experiment on any GPU. The
+ update-step legacy keys had to be hand-derived per hardware tier,
+ which was both error-prone and obscured the conceptual quantity
+ (e.g. *final beta*, not *per-update decay constant*).
+
+ The conversion for ``DAGGER_BETA_FINAL`` requires ``NUM_UPDATES``,
+ so this function MUST be called after :func:`resolve_num_updates`
+ when resolving online mode.
+
+ Idempotent: calling this twice is equivalent to calling it once.
+
+ Args:
+ config: Upper-cased config dict. Must contain ``NUM_STEPS`` and
+ ``NUM_ENVS``.
+ mode: Either ``"offline"`` or ``"online"``.
+
+ Raises:
+ ValueError: If ``DAGGER_BETA_FINAL`` is set in online mode but
+ ``NUM_UPDATES`` has not been resolved yet.
+ """
+ fpu = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"])
+
+ # float() first to accept YAML scientific notation parsed as string
+ # (PyYAML 1.1 only auto-coerces "3.0e+8", not "3e8" or "3.0e8").
+ # ── Mode-agnostic ────────────────────────────────────────────────
+ warmup_frames = config.get("LR_WARMUP_FRAMES")
+ if warmup_frames is not None:
+ config["LR_WARMUP_STEPS"] = int(float(warmup_frames)) // fpu
+
+ val_frames = config.get("VAL_INTERVAL_FRAMES")
+ if val_frames is not None:
+ config["VAL_INTERVAL"] = max(1, int(float(val_frames)) // fpu)
+
+ # ── Online-only ──────────────────────────────────────────────────
+ if mode != "online":
+ return
+
+ beta_final = config.get("DAGGER_BETA_FINAL")
+ if beta_final is not None:
+ num_updates = config.get("NUM_UPDATES")
+ if num_updates is None:
+ raise ValueError(
+ "DAGGER_BETA_FINAL requires NUM_UPDATES to be resolved "
+ "first; call resolve_num_updates() before "
+ "resolve_scaled_hyperparams()."
+ )
+ beta_init = float(config.get("DAGGER_BETA_INIT", 1.0))
+ # final = init * decay^N => decay = (final / init) ** (1 / N)
+ config["DAGGER_BETA_DECAY"] = (
+ float(beta_final) / beta_init
+ ) ** (1.0 / int(num_updates))
+
+ buffer_cycles = config.get("DAGGER_BUFFER_CYCLES")
+ if buffer_cycles is not None:
+ config["DAGGER_BUFFER_MAX"] = max(1, int(round(float(buffer_cycles) * fpu)))
+
+
+def print_config_snapshot(config: dict[str, Any], mode: str) -> None:
+ """Print a structured banner of training-critical hyperparameters.
+
+ Surfaces fairness-critical, schedule, and architecture parameters at
+ the start of every offline/online run so cross-hardware comparisons
+ can be sanity-checked at a glance. Must be called AFTER
+ :func:`resolve_num_updates` and :func:`resolve_scaled_hyperparams`
+ so the printed values reflect what training will actually use.
+
+ Args:
+ config: Upper-cased config dict (post-resolver).
+ mode: Either ``"offline"`` or ``"online"``.
+ """
+ fpu = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"])
+ num_updates = int(config["NUM_UPDATES"])
+ minibatch = fpu // int(config["NUM_MINIBATCHES"])
+ ts_key = f"{mode.upper()}_TOTAL_TIMESTEPS"
+ total_frames = int(config[ts_key])
+
+ bar = "=" * 72
+ title = f"{mode.upper()} training — config snapshot"
+ print(f"\n{bar}\n {title}\n{bar}")
+ print(f" env_name : {config['ENV_NAME']}")
+ print(f" seed : {config['SEED']}")
+
+ print(" -- Rollout / hardware --")
+ print(f" num_envs = {config['NUM_ENVS']}")
+ print(f" num_steps = {config['NUM_STEPS']}")
+ print(f" fpu (envs*steps) = {fpu}")
+ print(f" num_minibatches = {config['NUM_MINIBATCHES']} (minibatch={minibatch})")
+ print(f" update_epochs = {config['UPDATE_EPOCHS']}")
+ print(f" num_repeats = {config.get('NUM_REPEATS', 1)}")
+
+ print(" -- Schedule --")
+ print(f" {ts_key.lower():<24} = {total_frames:,} (~{total_frames/1e6:.1f}M frames)")
+ print(f" {'num_updates':<24} = {num_updates:,}")
+ warmup = int(config.get("LR_WARMUP_STEPS", 0))
+ print(f" {'lr':<24} = {float(config['LR']):.2e}")
+ print(f" {'lr_warmup_steps':<24} = {warmup} (~{warmup * fpu / 1e6:.2f}M frames)")
+ print(f" {'max_grad_norm':<24} = {config.get('MAX_GRAD_NORM', 1.0)}")
+
+ if mode == "online":
+ beta_init = float(config.get("DAGGER_BETA_INIT", 1.0))
+ beta_decay = float(config["DAGGER_BETA_DECAY"])
+ final_beta = beta_init * beta_decay ** num_updates
+ buffer_max = int(config["DAGGER_BUFFER_MAX"])
+ cycles = buffer_max / fpu
+ # Mirrors the n_train_passes default in run_online: drawn fresh per
+ # update, capped at samples_per_update for memory.
+ plan_h = int(config["PLAN_HORIZON"])
+ samples_per_update = int(config["NUM_ENVS"]) * (
+ int(config["NUM_STEPS"]) - plan_h + 1
+ )
+ n_passes = config.get("DAGGER_TRAIN_PASSES") or max(
+ 1, buffer_max // max(1, samples_per_update)
+ )
+ expert_det = bool(config.get("DAGGER_EXPERT_DETERMINISTIC", True))
+ total_grad_steps = (
+ num_updates * int(n_passes)
+ * int(config["UPDATE_EPOCHS"]) * int(config["NUM_MINIBATCHES"])
+ )
+ passes_tag = "auto" if config.get("DAGGER_TRAIN_PASSES") is None else "override"
+ print(" -- DAgger --")
+ print(f" {'dagger_beta_init':<24} = {beta_init}")
+ print(f" {'dagger_beta_decay':<24} = {beta_decay:.10f}")
+ print(f" {'final beta':<24} = {final_beta:.4f} (init * decay^N)")
+ print(f" {'dagger_buffer_max':<24} = {buffer_max:,} (~{cycles:.2f} update cycles)")
+ print(f" {'samples_per_update':<24} = {samples_per_update:,}")
+ print(f" {'dagger_train_passes':<24} = {n_passes} ({passes_tag})")
+ print(f" {'dagger_expert_determ':<24} = {expert_det}")
+ print(f" {'total_grad_steps':<24} = {total_grad_steps:,}")
+ else:
+ total_grad_steps = (
+ num_updates * int(config["UPDATE_EPOCHS"]) * int(config["NUM_MINIBATCHES"])
+ )
+ print(f" {'total_grad_steps':<24} = {total_grad_steps:,}")
+
+ val_int = int(config.get("VAL_INTERVAL", 0))
+ print(" -- Validation --")
+ print(f" val_interval = {val_int} updates (~{val_int * fpu / 1e6:.2f}M frames)")
+ print(f" val_diffusion_steps = {config.get('VAL_DIFFUSION_STEPS')}")
+ print(f" val_replan_every = {config.get('VAL_REPLAN_EVERY')}")
+ print(f" val_steps = {config.get('VAL_STEPS')}")
+
+ print(" -- Diffusion model --")
+ print(
+ f" d_model/n_heads/n_layers/d_ff = "
+ f"{config['D_MODEL']}/{config['N_HEADS']}/{config['N_LAYERS']}/{config['D_FF']}"
+ )
+ print(f" plan_horizon = {config['PLAN_HORIZON']}")
+ print(f" diffusion_steps = {config['DIFFUSION_STEPS']}")
+ print(f" remask_strategy = {config.get('REMASK_STRATEGY')} eta={config.get('ETA')}")
+ print(
+ f" sampling: temp={config.get('TEMPERATURE')} top_p={config.get('TOP_P')} "
+ f"loop={config.get('USE_LOOP')} t_on/t_off={config.get('T_ON')}/{config.get('T_OFF')}"
+ )
+ print(f"{bar}\n", flush=True)
+
+
+def _action_stats(
+ acts: jnp.ndarray,
+ num_actions: int,
+ valid: jnp.ndarray,
+) -> dict[str, jnp.ndarray]:
+ """Compute action-distribution entropy and unique-action fraction over valid windows.
+
+ Args:
+ acts: ``[B, H]`` int32 action sequences.
+ num_actions: Size of the real action vocabulary.
+ valid: ``[B]`` bool mask; invalid samples are excluded from counts.
+
+ Returns:
+ Dict with ``action_entropy`` and ``action_unique_frac``.
+ """
+ mask = jnp.broadcast_to(valid[:, None], acts.shape).reshape(-1)
+ flat = jnp.where(mask, acts.reshape(-1), num_actions + 1)
+ counts = jnp.bincount(flat, length=num_actions).astype(jnp.float32)
+ probs = counts / jnp.maximum(counts.sum(), 1.0)
+ entropy = -jnp.sum(probs * jnp.log(jnp.where(probs > 0, probs, 1.0)))
+ return {
+ "action_entropy": entropy,
+ "action_unique_frac": jnp.sum(probs > 0).astype(jnp.float32) / num_actions,
+ }
+
+
+def make_grad_step(
+ apply_train: Callable,
+ num_actions: int,
+ schedule_fn: ScheduleFn,
+ schedule_deriv_fn: ScheduleFn,
+ sigma_t: float,
+ label_smoothing: float,
+) -> Callable:
+ """Return a jittable gradient update function.
+
+ Args:
+ apply_train: Model apply function with dropout enabled.
+ num_actions: Size of the action vocabulary.
+ schedule_fn: alpha(t) noise schedule.
+ schedule_deriv_fn: d(alpha)/dt analytic derivative.
+ sigma_t: ReMDM remasking strength during training.
+ label_smoothing: Cross-entropy label smoothing epsilon.
+
+ Returns:
+ A ``step(state, acts, obs, valid, rng, advantages) -> (state, metrics)``
+ function ready for use inside ``jax.lax.scan``.
+ """
+
+ def _loss_fn(
+ params: Any,
+ acts: jnp.ndarray,
+ obs: jnp.ndarray,
+ valid: jnp.ndarray,
+ rng: jax.Array,
+ advantages: jnp.ndarray,
+ ) -> tuple[jnp.ndarray, dict]:
+ return compute_loss(
+ apply_train, params, rng, acts, obs, valid,
+ num_actions, schedule_fn, schedule_deriv_fn,
+ sigma_t=sigma_t, label_smoothing=label_smoothing,
+ advantages=advantages,
+ )
+
+ def step(
+ state: Any,
+ acts: jnp.ndarray,
+ obs: jnp.ndarray,
+ valid: jnp.ndarray,
+ rng: jax.Array,
+ advantages: jnp.ndarray,
+ ) -> tuple[Any, dict]:
+ """Single gradient update step.
+
+ Args:
+ state: Current ``TrainState``.
+ acts: ``[B, H]`` int32 action sequences.
+ obs: ``[B, obs_dim]`` float32 observations.
+ valid: ``[B]`` bool validity mask (episode-boundary filter).
+ rng: PRNG key for dropout and noise sampling.
+ advantages: ``[B]`` float per-sample weights applied before loss reduction.
+
+ Returns:
+ Updated ``TrainState`` and a metrics dict.
+ """
+ (_, info), grads = jax.value_and_grad(_loss_fn, has_aux=True)(
+ state.params, acts, obs, valid, rng, advantages,
+ )
+ state = state.apply_gradients(grads=grads)
+ info["grad_norm"] = optax.tree.norm(grads)
+ info.update(_action_stats(acts, num_actions, valid))
+ return state, info
+
+ return step
+
+
+def make_validate(
+ env: Any,
+ env_params: Any,
+ apply_eval: Callable,
+ num_actions: int,
+ plan_horizon: int,
+ schedule_fn: ScheduleFn,
+ config: dict[str, Any],
+ val_replan_every: int,
+ n_val_cycles: int,
+) -> Callable:
+ """Return a ``validate(state, rng) -> dict`` closure for periodic eval.
+
+ The closure runs a held-out rollout using the diffusion model's current
+ parameters and returns metrics under the ``val/`` namespace.
+
+ Args:
+ env: Batched Gymnax environment.
+ env_params: Gymnax environment params.
+ apply_eval: Model apply function (eval mode, no dropout).
+ num_actions: Size of the action vocabulary.
+ plan_horizon: Action plan length H.
+ schedule_fn: alpha(t) noise schedule.
+ config: Training config dict (read-only).
+ val_replan_every: Env steps executed per diffusion plan during validation.
+ n_val_cycles: Number of plan-execute cycles per validation rollout.
+
+ Returns:
+ A ``validate(state, rng) -> {str: jnp.ndarray}`` closure.
+ """
+
+ def validate(state: Any, rng: jax.Array) -> dict[str, jnp.ndarray]:
+ """Run a validation rollout and return ``val/`` metrics.
+
+ Args:
+ state: Current ``TrainState`` (only ``.params`` is used).
+ rng: PRNG key.
+
+ Returns:
+ Dict with ``val/`` prefixed metric keys.
+ """
+ rng, val_rng = jax.random.split(rng)
+ val_obs, val_env_state = env.reset(val_rng, env_params)
+
+ def _val_cycle(carry, _):
+ vs, vo, rng = carry
+ rng, p_rng = jax.random.split(rng)
+ plan = sample_plan(
+ apply_eval,
+ state.params,
+ p_rng,
+ vo,
+ num_actions,
+ plan_horizon,
+ num_steps=config.get("VAL_DIFFUSION_STEPS", 50),
+ schedule_fn=schedule_fn,
+ remask_strategy=config.get("REMASK_STRATEGY", "rescale"),
+ eta=config.get("ETA", 0.5),
+ use_loop=config.get("USE_LOOP", True),
+ t_on=config.get("T_ON", 0.7),
+ t_off=config.get("T_OFF", 0.3),
+ temperature=config.get("TEMPERATURE", 0.5),
+ top_p=config.get("TOP_P", 0.95),
+ ) # [num_envs, plan_horizon]
+
+ def _exec_step(inner_carry, step_i):
+ vs_i, vo_i, r = inner_carry
+ r, s_rng = jax.random.split(r)
+ vo_next, vs_next, _, _, info = env.step(
+ s_rng, vs_i, plan[:, step_i], env_params,
+ )
+ return (vs_next, vo_next, r), info
+
+ (vs, vo, rng), step_infos = jax.lax.scan(
+ _exec_step, (vs, vo, rng), jnp.arange(val_replan_every),
+ )
+ return (vs, vo, rng), step_infos
+
+ _, cycle_infos = jax.lax.scan(
+ _val_cycle, (val_env_state, val_obs, rng), None, n_val_cycles,
+ )
+ infos = jax.tree.map(
+ lambda x: x.reshape(-1, *x.shape[2:]), cycle_infos,
+ )
+ returned = infos["returned_episode"]
+ metrics = jax.tree.map(
+ lambda x: (x * returned).sum() / (returned.sum() + 1e-8),
+ infos,
+ )
+ return {f"val/{k}": v for k, v in metrics.items()}
+
+ return validate
diff --git a/src/planners/env.py b/src/planners/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac6f69b50038c44bd4b61268a6c43f845da31470
--- /dev/null
+++ b/src/planners/env.py
@@ -0,0 +1,48 @@
+"""Craftax environment construction and trajectory data structures."""
+
+from __future__ import annotations
+
+import jax.numpy as jnp
+from craftax.craftax_env import make_craftax_env_from_name
+from typing import NamedTuple
+
+from Craftax_Baselines.wrappers import (
+ AutoResetEnvWrapper,
+ BatchEnvWrapper,
+ LogWrapper,
+ OptimisticResetVecEnvWrapper,
+)
+
+
+class Transition(NamedTuple):
+ done: jnp.ndarray
+ action: jnp.ndarray
+ reward: jnp.ndarray
+ obs: jnp.ndarray
+ info: dict
+
+
+def make_env(config: dict, num_envs: int):
+ """Build a wrapped Craftax environment.
+
+ Args:
+ config: Upper-cased config dict with ``ENV_NAME``,
+ ``USE_OPTIMISTIC_RESETS``, ``OPTIMISTIC_RESET_RATIO``.
+ num_envs: Number of parallel environments.
+
+ Returns:
+ Tuple of ``(env, env_params)``.
+ """
+ env = make_craftax_env_from_name(config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"])
+ env_params = env.default_params
+ env = LogWrapper(env)
+ if config["USE_OPTIMISTIC_RESETS"]:
+ env = OptimisticResetVecEnvWrapper(
+ env,
+ num_envs=num_envs,
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], num_envs),
+ )
+ else:
+ env = AutoResetEnvWrapper(env)
+ env = BatchEnvWrapper(env, num_envs=num_envs)
+ return env, env_params
diff --git a/src/planners/inference.py b/src/planners/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e59e43a4771f38819e1a07b7ba2c30dbc35739
--- /dev/null
+++ b/src/planners/inference.py
@@ -0,0 +1,142 @@
+"""Evaluation: run a trained diffusion planner with MPC + historical inpainting."""
+
+from __future__ import annotations
+
+import time
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import wandb
+from craftax.craftax_env import make_craftax_env_from_name
+from craftax.craftax.constants import Achievement as FullCraftaxAchievements
+from craftax.craftax_classic.constants import Achievement as ClassicAchievements
+
+from src.diffusion.sampling import sample_plan_inpainting
+from .model import build_model, load_checkpoint, make_apply_fns
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+def run_inference(config: dict[str, Any]) -> None:
+ env_name = config["ENV_NAME"]
+ env = make_craftax_env_from_name(env_name, auto_reset=True)
+ env_params = env.default_params
+ num_actions = env.action_space(env_params).n
+ obs_dim = env.observation_space(env_params).shape[0]
+ config["NUM_ACTIONS"] = num_actions
+
+ num_envs = config.get("EVAL_NUM_ENVS", 32)
+ plan_horizon = config["PLAN_HORIZON"]
+ diffusion_steps = config.get("DIFFUSION_STEPS_EVAL", 10)
+ temperature = config.get("TEMPERATURE", 0.5)
+ top_p = config.get("TOP_P", 0.95)
+ eval_steps = int(float(config.get("EVAL_STEPS", 10000)))
+
+ model = build_model(config, num_actions)
+ apply_eval, _ = make_apply_fns(model)
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rng, ckpt_rng = jax.random.split(rng)
+ model_params = load_checkpoint(model, ckpt_rng, obs_dim, plan_horizon, config["CHECKPOINT_PATH"])
+ env_indices = jnp.arange(num_envs)
+
+ @jax.jit
+ def mpc_step(carry, _step_idx):
+ obs, state, rng, history, hist_len = carry
+ rng, plan_rng, env_rng = jax.random.split(rng, 3)
+
+ # Reset history when plan is exhausted
+ seq_full = hist_len >= plan_horizon
+ hist_len = jnp.where(seq_full, 0, hist_len)
+ history = jnp.where(seq_full[:, None], num_actions, history)
+
+ plan = sample_plan_inpainting(
+ apply_eval, model_params, plan_rng, obs,
+ history, hist_len, num_actions, plan_horizon,
+ diffusion_steps, temperature, top_p,
+ )
+
+ action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)
+ history = history.at[env_indices, hist_len].set(action)
+ hist_len = hist_len + 1
+
+ obs_next, state_next, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
+ jax.random.split(env_rng, num_envs), state, action, env_params,
+ )
+
+ hist_len = jnp.where(done, 0, hist_len)
+ history = jnp.where(done[:, None], num_actions, history)
+ return (obs_next, state_next, rng, history, hist_len), (reward, done, state_next.achievements)
+
+ print(f"\nRunning {num_envs} agents in {env_name} for {eval_steps} steps...")
+
+ rng, env_rng = jax.random.split(rng)
+ obs, state = jax.vmap(env.reset, in_axes=(0, None))(
+ jax.random.split(env_rng, num_envs), env_params,
+ )
+ history = jnp.full((num_envs, plan_horizon), num_actions, dtype=jnp.int32)
+ hist_len = jnp.zeros(num_envs, dtype=jnp.int32)
+
+ t0 = time.time()
+ _, (rewards, dones, achievements) = jax.lax.scan(
+ mpc_step, (obs, state, rng, history, hist_len), jnp.arange(eval_steps),
+ )
+ elapsed = time.time() - t0
+
+ # First-episode extraction (strict single-life evaluation)
+ rewards_np = np.array(rewards)
+ dones_np = np.array(dones)
+ ach_np = np.array(achievements)
+
+ ep_rewards = np.zeros(num_envs)
+ ep_ach = np.zeros((num_envs, ach_np.shape[2]))
+ ep_lengths = np.zeros(num_envs, dtype=int)
+
+ for i in range(num_envs):
+ death = np.where(dones_np[:, i])[0]
+ end = death[0] if len(death) > 0 else eval_steps - 1
+ ep_rewards[i] = rewards_np[:end + 1, i].sum()
+ ep_ach[i] = ach_np[:end + 1, i].max(axis=0)
+ ep_lengths[i] = end + 1
+
+ pct = ep_ach.mean(axis=0) * 100.0
+
+ # Report
+ print(f"\n{'=' * 50}")
+ print(f"EVALUATION COMPLETE ({elapsed:.1f}s)")
+ print(f"{'=' * 50}")
+ print(f"Average Score: {ep_rewards.mean():.1f} | Best: {ep_rewards.max():.1f}")
+
+ ach_cls = ClassicAchievements if "Classic" in env_name else FullCraftaxAchievements
+ ach_names = [(a.name.replace("_", " ").title(), a.name.lower()) for a in ach_cls]
+ valid = [i for i, p in enumerate(pct) if p > 0]
+ top_idx = max(valid) if valid else 5
+
+ for i in range(top_idx + 1):
+ name, _ = ach_names[i]
+ count = int(pct[i] / 100.0 * num_envs)
+ icon = "+" if count > 0 else "-"
+ print(f" [{icon}] {name}: {count}/{num_envs}")
+ print(f"{'=' * 50}")
+
+ if config.get("USE_WANDB", True):
+ wandb.init(
+ project=config.get("WANDB_PROJECT", "remdm-craftax"),
+ name=f"Eval-T{temperature}-P{top_p}",
+ config=config, job_type="evaluation",
+ )
+ summary = {"eval/average_score": float(ep_rewards.mean())}
+ for i in range(top_idx + 1):
+ summary[f"eval/achievements/{ach_names[i][1]}"] = pct[i]
+ wandb.log(summary)
+
+ table = wandb.Table(columns=["Agent", "Score", "Achievements", "Lifespan"])
+ unlocked = ep_ach.sum(axis=-1)
+ for i in range(num_envs):
+ table.add_data(f"Agent {i + 1}", float(ep_rewards[i]), int(unlocked[i]), int(ep_lengths[i]))
+ wandb.log({"Individual Results": table})
+ wandb.finish()
diff --git a/src/planners/logging.py b/src/planners/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6c824ce42baf6aa0aa916fdc6d154dcc6e8ea31
--- /dev/null
+++ b/src/planners/logging.py
@@ -0,0 +1,209 @@
+"""Centralised W&B logging utilities for ReMDM training loops.
+
+Design principles
+-----------------
+* ``build_log_dict`` is a pure function — no side effects, fully testable.
+* ``make_wandb_callback`` is a factory that returns a closure suitable for
+ ``jax.debug.callback``. All timing state is local to the closure; there is
+ no module-level global state.
+* Both ``train.py`` and ``online.py`` import the same symbols, keeping all
+ metric naming and aggregation logic in one place.
+
+Metric namespacing
+------------------
+``diffusion/`` — ELBO loss, accuracy, and noise-level diagnostics from ``compute_loss``.
+``train/`` — data quality, action distribution, throughput.
+``env/`` — episode returns and per-achievement unlock rates (training envs).
+``val/`` — same as ``env/`` but from the held-out validation rollout
+ (only emitted when ``step_idx % val_interval == 0``).
+``dagger/`` — DAgger-specific metrics (online training only).
+"""
+
+from __future__ import annotations
+
+import time
+from typing import Any
+
+import wandb
+
+
+def init_wandb(
+ config: dict[str, Any],
+ name: str,
+ *,
+ resume_run_id: str | None = None,
+) -> None:
+ """Initialise a W&B run, optionally resuming an existing one.
+
+ Args:
+ config: Training config dict (used for ``project``, ``entity``,
+ and logged as run config).
+ name: Human-readable run name.
+ resume_run_id: If provided, attaches to an existing W&B run via
+ ``wandb.init(id=..., resume="must")``. The run must
+ already exist.
+ """
+ kwargs: dict[str, Any] = {
+ "project": config.get("WANDB_PROJECT", "remdm-craftax"),
+ "entity": config.get("WANDB_ENTITY"),
+ "config": config,
+ }
+ if resume_run_id is not None:
+ kwargs["id"] = resume_run_id
+ kwargs["resume"] = "must"
+ else:
+ kwargs["name"] = name
+
+ wandb.init(**kwargs)
+
+
+# Keys emitted by ``src.diffusion.loss.compute_loss`` info dict.
+_DIFFUSION_KEYS: tuple[str, ...] = (
+ "loss",
+ "unweighted_loss",
+ "accuracy",
+ "acc_t_low",
+ "acc_t_mid",
+ "acc_t_high",
+ "frac_masked",
+ "mean_t",
+ "grad_norm",
+)
+
+# Keys added locally by training loops.
+_TRAIN_KEYS: tuple[str, ...] = (
+ "action_entropy",
+ "action_unique_frac",
+ "valid_frac",
+ "mean_return_weight",
+)
+
+# Keys specific to online DAgger training.
+_DAGGER_KEYS: tuple[str, ...] = (
+ "beta",
+ "reward_mean",
+ "buffer_fill",
+ "valid_frac",
+ "best_val_return",
+)
+
+
+def build_log_dict(
+ metric: dict[str, Any],
+ step_idx: int,
+ val_interval: int,
+ *,
+ is_online: bool = False,
+ sps: float | None = None,
+) -> dict[str, float]:
+ """Build a flat W&B-ready log dict from a merged training metric dict.
+
+ Args:
+ metric: Merged metric dict from the current update step.
+ step_idx: Integer update step index.
+ val_interval: How often (in steps) validation runs occur.
+ is_online: If ``True``, emit DAgger-specific keys under ``dagger/``.
+ sps: Pre-computed steps-per-second; omitted when ``None``.
+
+ Returns:
+ Flat ``{str: float}`` dict suitable for ``wandb.log``.
+ """
+ log: dict[str, float] = {}
+ is_val_step = (step_idx % val_interval == 0)
+
+ for k in _DIFFUSION_KEYS:
+ if k in metric:
+ log[f"diffusion/{k}"] = float(metric[k])
+
+ for k in _TRAIN_KEYS:
+ if k in metric:
+ log[f"train/{k}"] = float(metric[k])
+
+ if is_online:
+ for k in _DAGGER_KEYS:
+ if k in metric:
+ log[f"dagger/{k}"] = float(metric[k])
+
+ if "returned_episode_returns" in metric:
+ log["env/episode_return"] = float(metric["returned_episode_returns"])
+ if "returned_episode_lengths" in metric:
+ log["env/episode_length"] = float(metric["returned_episode_lengths"])
+
+ # Per-achievement breakdown + aggregate score (Craftax reports as %, divide by 100).
+ achieve_total = 0.0
+ for k, v in metric.items():
+ if "achievement" in k.lower() and not k.startswith("val/"):
+ log[f"env/achieve/{k}"] = float(v)
+ achieve_total += float(v) / 100.0
+ log["env/achievements"] = achieve_total
+
+ # Validation metrics — only emitted on val steps to avoid polluting charts with zeros.
+ if is_val_step:
+ val_achieve_total = 0.0
+ for k, v in metric.items():
+ if not k.startswith("val/"):
+ continue
+ inner = k[4:] # strip leading "val/"
+ if "achievement" in inner.lower():
+ log[f"val/achieve/{inner}"] = float(v)
+ val_achieve_total += float(v) / 100.0
+ elif inner == "returned_episode_returns":
+ log["val/episode_return"] = float(v)
+ elif inner == "returned_episode_lengths":
+ log["val/episode_length"] = float(v)
+ log["val/achievements"] = val_achieve_total
+
+ if sps is not None:
+ log["train/sps"] = sps
+
+ return log
+
+
+def make_wandb_callback(
+ config: dict[str, Any],
+ *,
+ steps_per_update: int | None,
+ val_interval: int,
+ is_online: bool = False,
+) -> Any:
+ """Return a host-side logging closure for ``jax.debug.callback``.
+
+ The closure tracks wall-clock time between successive calls to compute
+ steps-per-second. All state is local to the closure; there is no
+ module-level mutable state.
+
+ SPS is not reported on ``step_idx == 0`` (JIT compilation overhead) or
+ when ``steps_per_update`` is ``None`` (e.g. data-replay mode where no
+ environment frames are consumed).
+
+ Args:
+ config: Training config dict (read-only; only consulted for
+ ``USE_WANDB`` — callers are expected to guard).
+ steps_per_update: Environment frames consumed per update step. Pass
+ ``None`` to disable ``train/sps`` logging entirely
+ (e.g. when training from pre-collected data files).
+ val_interval: Frequency (in steps) at which validation runs occur.
+ is_online: If ``True``, emit DAgger keys under ``dagger/``.
+
+ Returns:
+ A callable ``log_fn(metric, step_idx) -> None`` for
+ ``jax.debug.callback``.
+ """
+ _t: list[float] = [time.time()]
+
+ def log_fn(metric: dict[str, Any], step_idx: int) -> None:
+ now = time.time()
+ dt = now - _t[0]
+ _t[0] = now
+
+ sps: float | None = (
+ steps_per_update / dt
+ if steps_per_update is not None and int(step_idx) > 0 and dt > 1e-6
+ else None
+ )
+ log = build_log_dict(
+ metric, int(step_idx), val_interval, is_online=is_online, sps=sps,
+ )
+ wandb.log(log, step=int(step_idx))
+
+ return log_fn
diff --git a/src/planners/model.py b/src/planners/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..72afa40ab71779d449a827792303cb4eab44e1a6
--- /dev/null
+++ b/src/planners/model.py
@@ -0,0 +1,325 @@
+"""Diffusion model lifecycle: construction, parameter init, checkpoint I/O, and apply closures."""
+
+from __future__ import annotations
+
+import json
+import logging
+from pathlib import Path
+from typing import Any, Callable, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import orbax.checkpoint as ocp
+from flax.training.train_state import TrainState
+
+from src.models.denoiser import DenoisingTransformer
+
+logger = logging.getLogger(__name__)
+
+_METADATA_FILENAME = "resume_metadata.json"
+
+
+def build_model(config: dict, num_actions: int) -> DenoisingTransformer:
+ """Construct a :class:`DenoisingTransformer` from a config dict.
+
+ Args:
+ config: Upper-cased config dict with architecture hyperparameters.
+ num_actions: Size of the discrete action vocabulary.
+
+ Returns:
+ An uninitialised :class:`DenoisingTransformer` instance.
+ """
+ return DenoisingTransformer(
+ num_actions=num_actions,
+ plan_horizon=config["PLAN_HORIZON"],
+ d_model=config.get("D_MODEL", 256),
+ n_heads=config.get("N_HEADS", 4),
+ n_layers=config.get("N_LAYERS", 4),
+ d_ff=config.get("D_FF", 512),
+ obs_encoder_layers=config.get("OBS_ENCODER_LAYERS", 2),
+ obs_encoder_width=config.get("OBS_ENCODER_WIDTH", 512),
+ dropout_rate=config.get("DROPOUT_RATE", 0.1),
+ )
+
+
+def init_params(
+ model: DenoisingTransformer,
+ rng: jax.Array,
+ obs_dim: int,
+ plan_horizon: int,
+) -> Any:
+ """Initialize model parameters with dummy inputs.
+
+ Args:
+ model: Flax module to initialise.
+ rng: PRNG key.
+ obs_dim: Observation dimensionality.
+ plan_horizon: Number of action steps in a plan.
+
+ Returns:
+ Initialised parameter pytree.
+ """
+ return model.init(
+ rng,
+ jnp.zeros((1, obs_dim)),
+ jnp.zeros((1, plan_horizon), dtype=jnp.int32),
+ jnp.zeros((1,)),
+ )
+
+
+def resolve_checkpoint_path(
+ path: str,
+ download_dir: str | None = None,
+) -> str:
+ """Resolve a checkpoint path, downloading from W&B if it is an artifact reference.
+
+ Paths prefixed with ``wandb:`` are treated as W&B artifact references
+ (e.g. ``wandb:entity/project/name:version``) and downloaded locally
+ before returning the filesystem path.
+
+ Args:
+ path: Local filesystem path or ``wandb:``-prefixed artifact
+ reference.
+ download_dir: Root directory for downloaded artifacts. When ``None``,
+ falls back to the wandb default (``./artifacts/``).
+
+ Returns:
+ Local filesystem path to the checkpoint directory.
+ """
+ if not path.startswith("wandb:"):
+ return str(Path(path).resolve())
+
+ import wandb
+
+ artifact_ref = path.removeprefix("wandb:")
+ api = wandb.Api()
+ artifact = api.artifact(artifact_ref)
+ local_path = (
+ artifact.download(root=download_dir) if download_dir else artifact.download()
+ )
+ print(f"Downloaded W&B artifact '{artifact_ref}' -> '{local_path}'")
+ return local_path
+
+
+def load_checkpoint(
+ model: DenoisingTransformer,
+ rng: jax.Array,
+ obs_dim: int,
+ plan_horizon: int,
+ path: str,
+) -> Any:
+ """Load diffusion model parameters from an Orbax checkpoint.
+
+ Args:
+ model: Flax module (used to build the abstract state structure).
+ rng: PRNG key for dummy initialisation.
+ obs_dim: Observation dimensionality.
+ plan_horizon: Number of action steps in a plan.
+ path: Path to the Orbax checkpoint directory.
+
+ Returns:
+ Restored parameter pytree.
+
+ Raises:
+ FileNotFoundError: If the checkpoint directory contains no saved steps.
+ """
+ path = str(Path(path).resolve())
+ params = init_params(model, rng, obs_dim, plan_horizon)
+ abstract_state = create_train_state(model=model, params=params, lr=1e-4, max_grad_norm=1.0)
+
+ with ocp.CheckpointManager(path) as mgr:
+ step = mgr.latest_step()
+ if step is None:
+ raise FileNotFoundError(f"No checkpoint at {path}")
+ restored_state = mgr.restore(
+ step,
+ args=ocp.args.StandardRestore(item=abstract_state),
+ )
+
+ print(f"Loaded diffusion checkpoint from '{path}' (step {step})")
+ return restored_state.params
+
+
+def create_train_state(
+ model: DenoisingTransformer,
+ params: Any,
+ lr: Union[float, Callable[[int], float]],
+ max_grad_norm: float,
+) -> TrainState:
+ """Create a :class:`TrainState` with gradient clipping and Adam.
+
+ Args:
+ model: Flax module (used only to bind ``apply_fn``).
+ params: Initialised parameter pytree.
+ lr: Constant learning rate or an optax schedule
+ (any callable ``step -> lr``).
+ max_grad_norm: Global gradient clipping threshold.
+
+ Returns:
+ A Flax ``TrainState`` ready for ``apply_gradients``.
+ """
+ tx = optax.chain(optax.clip_by_global_norm(max_grad_norm), optax.adam(lr, eps=1e-5))
+ return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
+
+
+def make_apply_fns(
+ model: DenoisingTransformer,
+) -> tuple[Callable, Callable]:
+ """Return ``(apply_eval, apply_train)`` closures matching ``ModelApplyFn``.
+
+ Args:
+ model: Flax module.
+
+ Returns:
+ Tuple of ``(apply_eval, apply_train)`` where ``apply_train`` enables
+ dropout via ``rngs={"dropout": rng}``.
+ """
+
+ def apply_eval(params: Any, obs: jnp.ndarray, z_t: jnp.ndarray, t: jnp.ndarray, _rng=None):
+ return model.apply(params, obs, z_t, t)
+
+ def apply_train(params: Any, obs: jnp.ndarray, z_t: jnp.ndarray, t: jnp.ndarray, rng=None):
+ return model.apply(
+ params, obs, z_t, t,
+ deterministic=False,
+ rngs={"dropout": rng} if rng is not None else {},
+ )
+
+ return apply_eval, apply_train
+
+
+# ---------------------------------------------------------------------------
+# Checkpoint metadata sidecar
+# ---------------------------------------------------------------------------
+
+
+class _NumpyEncoder(json.JSONEncoder):
+ """JSON encoder that handles numpy scalar types."""
+
+ def default(self, o: Any) -> Any:
+ """Serialize numpy scalars to native Python types.
+
+ Args:
+ o: Object to serialize.
+
+ Returns:
+ JSON-serializable object.
+ """
+ if isinstance(o, np.integer):
+ return int(o)
+ if isinstance(o, np.floating):
+ return float(o)
+ if isinstance(o, np.ndarray):
+ return o.tolist()
+ return super().default(o)
+
+
+def save_checkpoint_metadata(
+ checkpoint_dir: str,
+ mode: str,
+ update_step: int,
+ total_gradient_steps: int,
+ wandb_run_id: str | None,
+ config: dict[str, Any],
+) -> None:
+ """Write a JSON metadata sidecar alongside an Orbax checkpoint.
+
+ Args:
+ checkpoint_dir: Root directory of the Orbax checkpoint manager.
+ mode: Training mode (``"offline"`` or ``"online"``).
+ update_step: Final update step index.
+ total_gradient_steps: Total gradient steps completed.
+ wandb_run_id: Current W&B run ID, or ``None``.
+ config: Full training config snapshot.
+ """
+ metadata = {
+ "mode": mode,
+ "update_step": int(update_step),
+ "total_gradient_steps_completed": int(total_gradient_steps),
+ "wandb_run_id": wandb_run_id,
+ "config_snapshot": config,
+ }
+ path = Path(checkpoint_dir) / _METADATA_FILENAME
+ with open(path, "w") as f:
+ json.dump(metadata, f, indent=2, cls=_NumpyEncoder)
+ print(f"Saved checkpoint metadata to {path}")
+
+
+def load_checkpoint_metadata(
+ checkpoint_dir: str,
+) -> dict[str, Any] | None:
+ """Read the JSON metadata sidecar from a checkpoint directory.
+
+ Args:
+ checkpoint_dir: Root directory of the Orbax checkpoint manager.
+
+ Returns:
+ Parsed metadata dict, or ``None`` if the sidecar does not exist
+ (backward-compatible with checkpoints created before this feature).
+ """
+ path = Path(checkpoint_dir) / _METADATA_FILENAME
+ if not path.exists():
+ return None
+ with open(path) as f:
+ return json.load(f)
+
+
+def load_checkpoint_for_resume(
+ model: DenoisingTransformer,
+ rng: jax.Array,
+ obs_dim: int,
+ plan_horizon: int,
+ path: str,
+ lr_schedule: Union[float, Callable[[int], float]],
+ max_grad_norm: float,
+) -> TrainState:
+ """Load a full ``TrainState`` (params + optimizer state) for resume.
+
+ Unlike :func:`load_checkpoint` which returns only params, this function
+ restores the complete ``TrainState`` including Adam moments so that
+ training can continue seamlessly.
+
+ The ``lr_schedule`` and ``max_grad_norm`` must match the optimizer chain
+ structure used when the checkpoint was saved (same chain composition,
+ possibly different schedule values).
+
+ Args:
+ model: Flax module (used to build the abstract state).
+ rng: PRNG key for dummy initialisation.
+ obs_dim: Observation dimensionality.
+ plan_horizon: Number of action steps in a plan.
+ path: Path to the Orbax checkpoint directory.
+ lr_schedule: Learning rate or schedule matching the current run's
+ optimizer (must produce the same ``opt_state`` structure).
+ max_grad_norm: Global gradient clipping threshold.
+
+ Returns:
+ Restored ``TrainState`` with params, opt_state, and step from the
+ checkpoint. The caller should call ``.replace(step=...)`` to set the
+ correct LR offset for the resumed run.
+
+ Raises:
+ FileNotFoundError: If the checkpoint directory contains no saved steps.
+ """
+ path = str(Path(path).resolve())
+ params = init_params(model, rng, obs_dim, plan_horizon)
+ abstract_state = create_train_state(model, params, lr_schedule, max_grad_norm)
+
+ with ocp.CheckpointManager(path) as mgr:
+ step = mgr.latest_step()
+ if step is None:
+ raise FileNotFoundError(
+ f"No checkpoint found at {path}"
+ )
+ restored_state = mgr.restore(
+ step,
+ args=ocp.args.StandardRestore(item=abstract_state),
+ )
+
+ print(
+ f"Loaded full TrainState for resume from '{path}' "
+ f"(step {step}, opt_state step {restored_state.step})"
+ )
+ return restored_state
diff --git a/src/planners/offline.py b/src/planners/offline.py
new file mode 100644
index 0000000000000000000000000000000000000000..3963bab4ccfe68d5ff15bfa7a5b46d0e5f299306
--- /dev/null
+++ b/src/planners/offline.py
@@ -0,0 +1,357 @@
+"""Training loop: environment rollout -> diffusion window extraction -> gradient updates."""
+
+from __future__ import annotations
+
+import os
+import time
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import optax
+import orbax.checkpoint as ocp
+import wandb
+
+from src.diffusion.schedules import SCHEDULE_MAP
+from .common import (
+ make_grad_step,
+ make_validate,
+ print_config_snapshot,
+ resolve_num_updates,
+ resolve_scaled_hyperparams,
+)
+from .env import Transition, make_env
+from .model import (
+ build_model,
+ init_params,
+ create_train_state,
+ load_checkpoint_for_resume,
+ make_apply_fns,
+ save_checkpoint_metadata,
+)
+from .ppo import PPOAgent, build_ppo_network, load_ppo_params
+from .logging import init_wandb, make_wandb_callback
+
+
+# ---------------------------------------------------------------------------
+# make_train
+# ---------------------------------------------------------------------------
+
+def make_train(config: dict[str, Any]):
+ """Build the offline diffusion training closure.
+
+ All environment construction, model instantiation, and static pre-computation
+ happen here (outside the returned ``train`` closure) so they are not repeated
+ across ``jax.vmap`` replicas or JIT retraces.
+
+ Args:
+ config: Upper-cased hyperparameter dict (see ``configs/defaults.yaml``).
+
+ Returns:
+ A ``train(rng) -> dict`` closure that is safe to JIT and vmap.
+ """
+ num_steps = config["NUM_STEPS"]
+ num_envs = config["NUM_ENVS"]
+ plan_horizon = config["PLAN_HORIZON"]
+ val_interval = config.get("VAL_INTERVAL", 50)
+ val_replan_every = config.get("VAL_REPLAN_EVERY", 4)
+ val_steps = config.get("VAL_STEPS", 128)
+ n_val_cycles = val_steps // val_replan_every
+ valid_per_rollout = num_steps - plan_horizon + 1
+ num_samples = num_envs * valid_per_rollout
+ return_weight_cap = config.get("RETURN_WEIGHT_CAP", 5.0)
+
+ # NUM_UPDATES and OFFLINE_TOTAL_TIMESTEPS are resolved in
+ # run_offline_diffusion before wandb.init so the run name can use
+ # OFFLINE_TOTAL_TIMESTEPS. We assume both are present here.
+ assert num_samples % config["NUM_MINIBATCHES"] == 0, (
+ f"{num_samples} samples not divisible by {config['NUM_MINIBATCHES']} minibatches"
+ )
+ config["MINIBATCH_SIZE"] = num_samples // config["NUM_MINIBATCHES"]
+
+ # Environment
+ env, env_params = make_env(config, num_envs)
+
+ num_actions = env.action_space(env_params).n
+ obs_shape = env.observation_space(env_params).shape
+ obs_dim = obs_shape[0]
+
+ # PPO collector
+ model_type = config["PPO_MODEL_TYPE"]
+ ppo_net = build_ppo_network(model_type, num_actions, config["LAYER_SIZE"], config)
+ ppo_params = load_ppo_params(
+ config["PPO_CHECKPOINT_PATH"], ppo_net, model_type, num_envs, obs_shape, config["LAYER_SIZE"],
+ )
+ ppo = PPOAgent(ppo_net, ppo_params, model_type, config["LAYER_SIZE"])
+
+ # Noise schedule
+ schedule_fn, schedule_deriv_fn = SCHEDULE_MAP[config["DIFFUSION_SCHEDULE"]]
+
+ # Diffusion model — pure Flax dataclass, no randomness, safe to build once.
+ net = build_model(config, num_actions)
+ apply_eval, apply_train = make_apply_fns(net)
+ grad_step = make_grad_step(
+ apply_train, num_actions, schedule_fn, schedule_deriv_fn,
+ config.get("TRAIN_SIGMA", 0.0), config.get("LABEL_SMOOTHING", 0.0),
+ )
+
+ # Cosine LR decay over total gradient steps with optional linear warm-up.
+ total_grad_steps = config["NUM_UPDATES"] * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"]
+ warmup_steps = config.get("LR_WARMUP_STEPS", 0)
+ lr_schedule = (
+ optax.warmup_cosine_decay_schedule(
+ init_value=0.0,
+ peak_value=config["LR"],
+ warmup_steps=warmup_steps,
+ decay_steps=total_grad_steps,
+ end_value=config["LR"] * 0.1,
+ )
+ if warmup_steps > 0
+ else optax.cosine_decay_schedule(
+ init_value=config["LR"],
+ decay_steps=total_grad_steps,
+ alpha=0.1,
+ )
+ )
+
+ # Resume checkpoint (loaded outside JIT, captured by train closure) ------
+ resume_step = config.get("RESUME_STEP") or 0
+ resume_state = None
+ if config.get("RESUME_CHECKPOINT_PATH"):
+ resume_state = load_checkpoint_for_resume(
+ net,
+ jax.random.PRNGKey(0),
+ obs_dim,
+ plan_horizon,
+ config["RESUME_CHECKPOINT_PATH"],
+ lr_schedule,
+ config["MAX_GRAD_NORM"],
+ )
+ # Set the optimizer step counter so the LR schedule picks up at the
+ # correct position. The schedule is indexed by gradient step, which
+ # equals update_step * update_epochs * num_minibatches.
+ target_opt_step = resume_step * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"]
+ resume_state = resume_state.replace(step=target_opt_step)
+
+ scan_length = config["NUM_UPDATES"] - resume_step
+
+ # W&B callback — one closure shared across vmap replicas (timing is per-call).
+ _wandb_log = (
+ make_wandb_callback(
+ config,
+ steps_per_update=num_steps * num_envs,
+ val_interval=val_interval,
+ )
+ if config["USE_WANDB"] else None
+ )
+
+ def train(rng: jax.Array) -> dict[str, Any]:
+ """JIT/vmap-compatible training loop.
+
+ Args:
+ rng: JAX PRNG key (one per vmap replica).
+
+ Returns:
+ Dict with ``runner_state`` (final scan carry) and ``metrics`` (all update metrics).
+ """
+ rng, init_rng, env_rng = jax.random.split(rng, 3)
+ if resume_state is not None:
+ state = resume_state
+ else:
+ params = init_params(net, init_rng, obs_dim, plan_horizon)
+ state = create_train_state(net, params, lr_schedule, config["MAX_GRAD_NORM"])
+
+ obsv, env_state = env.reset(env_rng, env_params)
+ init_hstate = ppo.init_hidden(num_envs)
+
+ # Shared validation closure (see common.py)
+ _validate = make_validate(
+ env, env_params, apply_eval, num_actions,
+ plan_horizon, schedule_fn, config,
+ val_replan_every, n_val_cycles,
+ )
+
+ # ------------------------------------------------------------------
+ # Update step
+ # ------------------------------------------------------------------
+ def _update_step(runner, _):
+ state, env_state, last_obs, last_done, hstate, rng, step_idx = runner
+
+ # --- Trajectory collection (state excluded from carry: not modified here) ---
+ def _env_step(carry, _):
+ es, obs, done, hs, rng = carry
+ rng, act_rng, step_rng = jax.random.split(rng, 3)
+ action, new_hs = ppo.act(
+ obs, done, hs, act_rng, temperature=config.get("COLLECT_TEMPERATURE", 1.0),
+ )
+ new_obs, es, reward, new_done, info = env.step(step_rng, es, action, env_params)
+ t = Transition(done=done, action=action, reward=reward, obs=obs, info=info)
+ return (es, new_obs, new_done, new_hs, rng), t
+
+ (env_state, last_obs, last_done, hstate, rng), traj = jax.lax.scan(
+ _env_step, (env_state, last_obs, last_done, hstate, rng), None, num_steps,
+ )
+
+ # --- Diffusion window extraction ---
+ def _window(t_idx):
+ obs_t = traj.obs[t_idx]
+ acts = jax.lax.dynamic_slice(traj.action, (t_idx, 0), (plan_horizon, num_envs))
+ # traj.done[t] marks a reset *before* step t, so traj.done[t_idx]
+ # only tells us obs_t is an episode-start — it does NOT invalidate the
+ # window. We check done flags strictly *inside* the action sequence.
+ dones = jax.lax.dynamic_slice(
+ traj.done, (t_idx + 1, 0), (plan_horizon - 1, num_envs),
+ )
+ valid = ~jnp.any(dones, axis=0)
+
+ rew_seq = jax.lax.dynamic_slice(traj.reward, (t_idx, 0), (plan_horizon, num_envs))
+ window_return = jnp.sum(rew_seq, axis=0) # [num_envs]
+
+ return obs_t, jnp.swapaxes(acts, 0, 1), valid, window_return
+
+ obs_w, act_w, valid_w, returns_w = jax.vmap(_window)(jnp.arange(valid_per_rollout))
+
+ flat_obs = obs_w.reshape(-1, obs_dim)
+ flat_acts = act_w.reshape(-1, plan_horizon)
+ flat_valid = valid_w.reshape(-1) # bool: episode-boundary filter
+
+ # Return-weighted advantages: normalise by batch mean, clip to [0.1, cap].
+ # Passed as per-sample multipliers into compute_loss *after* per-position
+ # normalisation, so the weight correctly scales each sample's contribution.
+ flat_returns = returns_w.reshape(-1)
+ flat_returns_clipped = jnp.clip(flat_returns, 0.0, None)
+ return_weights = flat_returns_clipped / (jnp.mean(flat_returns_clipped) + 1e-8)
+ return_weights = jnp.clip(return_weights, 0.1, return_weight_cap)
+
+ dataset = (flat_obs, flat_acts, flat_valid, return_weights)
+
+ # --- Minibatch SGD over UPDATE_EPOCHS epochs ---
+ def _epoch(epoch_state, _):
+ state, ds, rng = epoch_state
+ rng, perm_rng = jax.random.split(rng)
+ perm = jax.random.permutation(perm_rng, num_samples)
+ shuffled = jax.tree.map(lambda x: jnp.take(x, perm, axis=0), ds)
+ batches = jax.tree.map(
+ lambda x: x.reshape(config["NUM_MINIBATCHES"], -1, *x.shape[1:]), shuffled,
+ )
+
+ def _mb(carry, batch):
+ st, rng = carry
+ rng, loss_rng = jax.random.split(rng)
+ obs_b, act_b, val_b, adv_b = batch
+ st, metrics = grad_step(st, act_b, obs_b, val_b, loss_rng, adv_b)
+ return (st, rng), metrics
+
+ (state, rng), metrics = jax.lax.scan(_mb, (state, rng), batches)
+ return (state, ds, rng), metrics
+
+ (state, _, rng), loss_info = jax.lax.scan(
+ _epoch, (state, dataset, rng), None, config["UPDATE_EPOCHS"],
+ )
+
+ # --- Metrics ---
+ metric = jax.tree.map(jnp.mean, loss_info)
+ returned = traj.info["returned_episode"]
+ env_metrics = jax.tree.map(
+ lambda x: (x * returned).sum() / (returned.sum() + 1e-8), traj.info,
+ )
+ metric.update(env_metrics)
+ metric["valid_frac"] = jnp.mean(flat_valid.astype(jnp.float32))
+ metric["mean_return_weight"] = jnp.mean(return_weights)
+
+ # --- Periodic validation ---
+ rng, val_rng = jax.random.split(rng)
+ dummy = jax.tree.map(
+ jnp.zeros_like, {f"val/{k}": v for k, v in env_metrics.items()},
+ )
+ val_metrics = jax.lax.cond(
+ step_idx % val_interval == 0,
+ lambda: _validate(state, val_rng),
+ lambda: dummy,
+ )
+ metric.update(val_metrics)
+
+ if _wandb_log is not None:
+ jax.debug.callback(_wandb_log, metric, step_idx)
+
+ runner = (state, env_state, last_obs, last_done, hstate, rng, step_idx + 1)
+ return runner, metric
+
+ rng, run_rng = jax.random.split(rng)
+ runner_init = (
+ state, env_state, obsv, jnp.zeros(num_envs, dtype=bool),
+ init_hstate, run_rng, resume_step,
+ )
+ runner_final, metrics = jax.lax.scan(_update_step, runner_init, None, scan_length)
+ return {"runner_state": runner_final, "metrics": metrics}
+
+ return train
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+def run_offline_diffusion(config):
+ """Configure, compile, and run offline diffusion training.
+
+ Args:
+ config: Mixed-case hyperparameter dict from ``defaults.yaml`` / CLI merge.
+ Keys are upper-cased on entry.
+ """
+ config = {k.upper(): v for k, v in config.items()}
+
+ # OFFLINE_TOTAL_TIMESTEPS (env frames) is the hardware-portable source of
+ # truth: invariant under num_envs changes, so the same config trains the
+ # same amount of environment experience on any GPU. OFFLINE_NUM_UPDATES
+ # is kept as a legacy fallback for configs that prefer the update form.
+ resolve_num_updates(config, "offline")
+ # Translate env-frame-denominated hyperparameters (LR_WARMUP_FRAMES,
+ # VAL_INTERVAL_FRAMES) into their update-step legacy keys.
+ resolve_scaled_hyperparams(config, "offline")
+ print_config_snapshot(config, "offline")
+
+ if config["USE_WANDB"]:
+ init_wandb(
+ config,
+ name=f"{config['ENV_NAME']}-OfflineDiffusion-BC-{int(config['OFFLINE_TOTAL_TIMESTEPS'] // 1e6)}M",
+ resume_run_id=config.get("RESUME_WANDB_RUN_ID"),
+ )
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
+
+ train_fn = jax.jit(jax.vmap(make_train(config)))
+
+ t0 = time.time()
+ out = train_fn(rngs)
+ elapsed = time.time() - t0
+ print(f"Time: {elapsed:.1f}s SPS: {config['OFFLINE_TOTAL_TIMESTEPS'] / elapsed:.0f}")
+
+ if config["USE_WANDB"] and config["SAVE_POLICY"]:
+ train_states = out["runner_state"][0]
+ train_state = jax.tree.map(lambda x: x[0], train_states)
+ path = os.path.join(wandb.run.dir, "policies")
+ with ocp.CheckpointManager(path, options=ocp.CheckpointManagerOptions(max_to_keep=1)) as mgr:
+ mgr.save(int(config["OFFLINE_TOTAL_TIMESTEPS"]), args=ocp.args.StandardSave(train_state))
+ print(f"Saved policy to {path}")
+
+ num_updates = config["NUM_UPDATES"]
+ save_checkpoint_metadata(
+ path,
+ mode="offline",
+ update_step=num_updates,
+ total_gradient_steps=num_updates * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"],
+ wandb_run_id=wandb.run.id if wandb.run else None,
+ config=config,
+ )
+
+ artifact = wandb.Artifact(
+ name=f"{config['ENV_NAME']}-policy",
+ type="model",
+ metadata=config
+ )
+ artifact.add_dir(path)
+ wandb.log_artifact(artifact)
+
+ print("Uploaded policy artifact to wandb")
diff --git a/src/planners/online.py b/src/planners/online.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd16d40f80f09bd3ae0d7a80a3ff7ce7e908d689
--- /dev/null
+++ b/src/planners/online.py
@@ -0,0 +1,759 @@
+"""Online DAgger training: roll out learner, label with expert, aggregate, update.
+
+Implements Algorithm 3.1 from Ross et al. (2011) — 'A Reduction of Imitation
+Learning and Structured Prediction to No-Regret Online Learning'.
+
+Each DAgger iteration:
+ 1. Roll out the mixed policy (beta * expert + (1-beta) * learner).
+ 2. At every visited state, query the expert for target actions.
+ 3. Aggregate (obs, expert_plan) pairs into a growing replay buffer.
+ 4. Train the diffusion model on the full buffer with BC loss (MDLM ELBO).
+ beta decays exponentially so the learner's own policy dominates rollouts.
+"""
+
+from __future__ import annotations
+
+import os
+import time
+from typing import Any, NamedTuple
+
+import jax
+import jax.numpy as jnp
+import optax
+import orbax.checkpoint as ocp
+import wandb
+from flax.training.train_state import TrainState
+
+from src.diffusion.sampling import sample_plan
+from src.diffusion.schedules import SCHEDULE_MAP
+
+from .common import (
+ make_grad_step,
+ make_validate,
+ print_config_snapshot,
+ resolve_num_updates,
+ resolve_scaled_hyperparams,
+)
+from .env import make_env
+from .model import (
+ build_model,
+ init_params,
+ load_checkpoint,
+ load_checkpoint_for_resume,
+ create_train_state,
+ make_apply_fns,
+ save_checkpoint_metadata,
+)
+from .ppo import PPOAgent, load_ppo_agent
+from .logging import init_wandb, make_wandb_callback
+
+
+# ---------------------------------------------------------------------------
+# Scan carry
+# ---------------------------------------------------------------------------
+
+
+class DAggerCarry(NamedTuple):
+ """Carry state for the outer DAgger update scan."""
+
+ train_state: Any
+ env_state: Any
+ obs: jnp.ndarray # [E, obs_dim]
+ rng: jax.Array
+ step_idx: int
+ ppo_hs: Any # [E, layer_size] or None (non-RNN)
+ prev_done: jnp.ndarray # [E] bool
+ buf_obs: jnp.ndarray # [max_buf, obs_dim]
+ buf_plans: jnp.ndarray # [max_buf, plan_horizon] int32
+ buf_valid: jnp.ndarray # [max_buf] float32
+ buf_write_idx: int
+ buf_fill: int
+ best_params: Any # copy of params with highest val return
+ best_val_return: jnp.ndarray # scalar, -inf initially
+
+
+# ---------------------------------------------------------------------------
+# make_train_dagger
+# ---------------------------------------------------------------------------
+
+
+def make_train_dagger(config: dict[str, Any]):
+ """Build the DAgger train closure.
+
+ All environment construction, model instantiation, and static
+ pre-computation happen here (outside the returned ``train`` closure) so
+ they are not repeated across ``jax.vmap`` replicas or JIT retraces.
+
+ Args:
+ config: Upper-cased hyperparameter dict (see ``configs/defaults.yaml``).
+
+ Returns:
+ A ``train(rng) -> dict`` closure that is safe to JIT and vmap.
+ """
+ num_envs = config["NUM_ENVS"]
+ plan_horizon = config["PLAN_HORIZON"]
+ num_updates = config["NUM_UPDATES"]
+ update_epochs = config["UPDATE_EPOCHS"]
+ num_minibatches = config["NUM_MINIBATCHES"]
+ diffusion_steps = config["DIFFUSION_STEPS"]
+ num_steps = config["NUM_STEPS"]
+
+ # Validation config
+ val_interval = config.get("VAL_INTERVAL", 50)
+ val_replan_every = config.get("VAL_REPLAN_EVERY", 4)
+ val_steps = config.get("VAL_STEPS", 128)
+ n_val_cycles = val_steps // val_replan_every
+
+ # Environment ----------------------------------------------------------
+ env, env_params = make_env(config, num_envs)
+ num_actions = env.action_space(env_params).n
+ obs_shape = env.observation_space(env_params).shape
+ obs_dim = obs_shape[0]
+
+ # Expert (PPO) — required for DAgger -----------------------------------
+ assert config.get("PPO_CHECKPOINT_PATH"), (
+ "DAgger requires an expert policy; set PPO_CHECKPOINT_PATH."
+ )
+ ppo: PPOAgent = load_ppo_agent(
+ config["PPO_CHECKPOINT_PATH"],
+ num_actions,
+ obs_dim,
+ config.get("LAYER_SIZE", 512),
+ config.get("PPO_MODEL_TYPE", "ppo_rnn"),
+ config,
+ num_envs=num_envs,
+ )
+
+ # Schedule -------------------------------------------------------------
+ schedule_fn, schedule_deriv_fn = SCHEDULE_MAP[config["DIFFUSION_SCHEDULE"]]
+
+ # Diffusion model / apply fns ------------------------------------------
+ model = build_model(config, num_actions)
+ apply_eval, apply_train = make_apply_fns(model)
+ grad_step = make_grad_step(
+ apply_train,
+ num_actions,
+ schedule_fn,
+ schedule_deriv_fn,
+ config.get("TRAIN_SIGMA", 0.0),
+ config.get("LABEL_SMOOTHING", 0.0),
+ )
+
+ # Pretrained checkpoint ------------------------------------------------
+ pretrained_params = None
+ if config.get("OFFLINE_CHECKPOINT_PATH"):
+ _tmp_rng = jax.random.PRNGKey(0)
+ pretrained_params = load_checkpoint(
+ model,
+ _tmp_rng,
+ obs_dim,
+ plan_horizon,
+ config["OFFLINE_CHECKPOINT_PATH"],
+ )
+
+ # B3: roll out num_steps env transitions across n_cycles plans, then
+ # extract sliding windows the same way offline.py does — yields
+ # W = num_steps - plan_horizon + 1 windows per env per update instead
+ # of only one window per cycle (~16x denser for the default config).
+ assert num_steps % plan_horizon == 0, (
+ f"NUM_STEPS ({num_steps}) must be divisible by"
+ f" PLAN_HORIZON ({plan_horizon})"
+ )
+ assert num_steps >= plan_horizon, (
+ f"NUM_STEPS ({num_steps}) must be >= PLAN_HORIZON ({plan_horizon})"
+ )
+ n_cycles = num_steps // plan_horizon
+ valid_per_rollout = num_steps - plan_horizon + 1
+ samples_per_update = num_envs * valid_per_rollout
+
+ assert samples_per_update % num_minibatches == 0, (
+ f"samples_per_update ({samples_per_update}) not divisible by"
+ f" num_minibatches ({num_minibatches})"
+ )
+
+ # B1: circular replay buffer sizing.
+ # Theoretical max = num_updates * samples_per_update; cap to stay in
+ # GPU memory. Each sample stores obs (float32) + plan (int32) + valid.
+ max_buffer_size = min(
+ num_updates * samples_per_update,
+ config.get("DAGGER_BUFFER_MAX", 100_000),
+ )
+ assert samples_per_update <= max_buffer_size, (
+ f"samples_per_update ({samples_per_update}) exceeds"
+ f" max_buffer_size ({max_buffer_size}); raise DAGGER_BUFFER_MAX"
+ f" or shrink NUM_ENVS / NUM_STEPS"
+ )
+
+ # B1: multi-pass training over the buffer. Each pass redraws a
+ # fresh sample of size ``samples_per_update`` from the filled portion
+ # of D. Default = 1 so DAgger does exactly the same gradient work
+ # per update as offline BC (``update_epochs * num_minibatches`` grad
+ # steps over a sample of size ``samples_per_update``) — fair compute
+ # comparison. After the Bug 3 sliding-window fix, ``samples_per_update``
+ # is already 16x larger than the legacy cycle-only count, so a single
+ # pass already yields ~34% buffer coverage and cumulative coverage
+ # across a few updates approaches 100%. Set DAGGER_TRAIN_PASSES > 1
+ # explicitly to trade BC fairness for higher per-update D coverage.
+ _default_passes = 1
+ n_train_passes = config.get("DAGGER_TRAIN_PASSES") or _default_passes
+
+ # B2: deterministic-by-default expert. Sampling from ``pi.logits``
+ # injects label noise — two queries to the same state can return
+ # different actions — which breaks DAgger's assumption of a fixed
+ # expert mapping s -> a*. Argmax keeps labels consistent.
+ expert_deterministic = config.get("DAGGER_EXPERT_DETERMINISTIC", True)
+
+ # Beta schedule: probability of using expert for rollout actions.
+ # beta_i = beta_init * beta_decay^i -> 0 as i -> inf.
+ beta_init = config.get("DAGGER_BETA_INIT", 1.0)
+ beta_decay = config.get("DAGGER_BETA_DECAY", 0.95)
+
+ # I8: cosine LR schedule matching offline training. Stretched to
+ # cover all gradient steps across train passes (B1).
+ total_grad_steps = (
+ num_updates * n_train_passes * update_epochs * num_minibatches
+ )
+ warmup_steps = config.get("LR_WARMUP_STEPS", 0)
+ lr_schedule = (
+ optax.warmup_cosine_decay_schedule(
+ init_value=0.0,
+ peak_value=config["LR"],
+ warmup_steps=warmup_steps,
+ decay_steps=total_grad_steps,
+ end_value=config["LR"] * 0.1,
+ )
+ if warmup_steps > 0
+ else optax.cosine_decay_schedule(
+ init_value=config["LR"],
+ decay_steps=total_grad_steps,
+ alpha=0.1,
+ )
+ )
+
+ # Resume checkpoint (loaded outside JIT, captured by train closure) ------
+ resume_step = config.get("RESUME_STEP") or 0
+ resume_state = None
+ if config.get("RESUME_CHECKPOINT_PATH"):
+ resume_state = load_checkpoint_for_resume(
+ model,
+ jax.random.PRNGKey(0),
+ obs_dim,
+ plan_horizon,
+ config["RESUME_CHECKPOINT_PATH"],
+ lr_schedule,
+ config["MAX_GRAD_NORM"],
+ )
+ target_opt_step = (
+ resume_step * n_train_passes * update_epochs * num_minibatches
+ )
+ resume_state = resume_state.replace(step=target_opt_step)
+
+ scan_length = num_updates - resume_step
+
+ # W&B ------------------------------------------------------------------
+ _wandb_log = (
+ make_wandb_callback(
+ config,
+ steps_per_update=num_envs * num_steps,
+ val_interval=val_interval,
+ is_online=True,
+ )
+ if config.get("USE_WANDB")
+ else None
+ )
+
+ # ------------------------------------------------------------------
+ # Train closure
+ # ------------------------------------------------------------------
+
+ def train(rng: jax.Array) -> dict[str, Any]:
+ """JIT/vmap-compatible DAgger training loop.
+
+ Args:
+ rng: JAX PRNG key (one per vmap replica).
+
+ Returns:
+ Dict with ``runner_state`` (final DAggerCarry) and ``metrics``.
+ """
+ rng, init_rng, env_rng = jax.random.split(rng, 3)
+
+ if resume_state is not None:
+ state = resume_state
+ params = resume_state.params
+ elif pretrained_params is not None:
+ params = pretrained_params
+ state = create_train_state(
+ model, params, lr_schedule, config["MAX_GRAD_NORM"],
+ )
+ else:
+ params = init_params(model, init_rng, obs_dim, plan_horizon)
+ state = create_train_state(
+ model, params, lr_schedule, config["MAX_GRAD_NORM"],
+ )
+
+ obs, env_state = env.reset(env_rng, env_params)
+
+ # Pre-allocate DAgger replay buffer
+ buf_obs = jnp.zeros((max_buffer_size, obs_dim))
+ buf_plans = jnp.zeros(
+ (max_buffer_size, plan_horizon), dtype=jnp.int32,
+ )
+ buf_valid = jnp.zeros(max_buffer_size)
+
+ # Shared validation closure (see common.py)
+ _validate = make_validate(
+ env, env_params, apply_eval, num_actions,
+ plan_horizon, schedule_fn, config,
+ val_replan_every, n_val_cycles,
+ )
+
+ # ----------------------------------------------------------
+ # _update_step (one DAgger iteration)
+ # ----------------------------------------------------------
+ def _update_step(carry: DAggerCarry, _):
+ (
+ state,
+ env_state,
+ obs,
+ rng,
+ step_idx,
+ ppo_hs,
+ prev_done,
+ buf_obs,
+ buf_plans,
+ buf_valid,
+ buf_write_idx,
+ buf_fill,
+ best_params,
+ best_val_return,
+ ) = carry
+
+ # Beta decays each update: expert -> learner
+ beta = beta_init * jnp.power(beta_decay, step_idx)
+
+ # --- Roll out with mixed policy, collect expert labels -
+ def _plan_and_execute(outer_carry, _):
+ es, cur_obs, rng, hs, p_done = outer_carry
+ rng, plan_rng, sim_rng = jax.random.split(rng, 3)
+
+ # Learner plan from the current diffusion policy
+ learner_plan = sample_plan(
+ apply_eval,
+ state.params,
+ plan_rng,
+ cur_obs,
+ num_actions,
+ plan_horizon,
+ diffusion_steps,
+ schedule_fn,
+ config.get("REMASK_STRATEGY", "rescale"),
+ config.get("ETA", 0.5),
+ config.get("USE_LOOP", True),
+ config.get("T_ON", 0.7),
+ config.get("T_OFF", 0.3),
+ config.get("TEMPERATURE", 1.0),
+ config.get("TOP_P", 0.95),
+ ) # [E, H]
+
+ # B3: simulate plan_horizon steps, recording the visited
+ # obs at every state alongside the expert action. The
+ # outer code then extracts sliding windows from the full
+ # per-step trace, mirroring offline.py.
+ def _sim_step(c, step_i):
+ st, o, r, hs, p_done = c
+ r, s_rng, mix_rng, ppo_rng = jax.random.split(
+ r, 4,
+ )
+
+ # Expert action with the correct done flag so the
+ # PPO RNN hidden state resets on episode boundaries.
+ pi, new_hs = ppo.get_pi(o, p_done, hs)
+ if expert_deterministic:
+ # B2: argmax keeps the expert mapping s -> a
+ # deterministic, removing label noise from D.
+ expert_act = jnp.argmax(
+ pi.logits, axis=-1,
+ ).squeeze(0)
+ else:
+ expert_act = jax.random.categorical(
+ ppo_rng, pi.logits,
+ ).squeeze(0)
+
+ # Learner action from the plan
+ learner_act = learner_plan[:, step_i]
+
+ # Mixed execution: prob beta -> expert, else learner
+ use_expert = jax.random.bernoulli(
+ mix_rng, beta, shape=(num_envs,),
+ )
+ exec_act = jnp.where(
+ use_expert, expert_act, learner_act,
+ )
+
+ o_next, st, rew, done, info = env.step(
+ s_rng, st, exec_act, env_params,
+ )
+ # Yield the visited obs ``o`` (not ``o_next``) so the
+ # paired (obs_t, expert_act_t) is consistent.
+ return (st, o_next, r, new_hs, done), (
+ o,
+ expert_act,
+ rew,
+ done,
+ info,
+ )
+
+ final_c, (
+ obs_seq, expert_acts, rews, dones, infos,
+ ) = jax.lax.scan(
+ _sim_step,
+ (es, cur_obs, sim_rng, hs, p_done),
+ jnp.arange(plan_horizon),
+ )
+ # obs_seq: [H, E, obs_dim]
+ # expert_acts: [H, E]
+ # dones: [H, E]
+ es_next, obs_next, _, hs_next, done_next = final_c
+
+ return (es_next, obs_next, rng, hs_next, done_next), (
+ obs_seq,
+ expert_acts,
+ rews,
+ dones,
+ infos,
+ )
+
+ # I7: ppo_hs and prev_done persist across cycles and
+ # updates via the scan carry.
+ (env_state, obs, rng, ppo_hs, prev_done), traj = (
+ jax.lax.scan(
+ _plan_and_execute,
+ (env_state, obs, rng, ppo_hs, prev_done),
+ None,
+ n_cycles,
+ )
+ )
+ # traj_obs: [C, H, E, obs_dim]
+ # traj_expert_acts: [C, H, E]
+ # traj_dones: [C, H, E]
+ (
+ traj_obs,
+ traj_expert_acts,
+ traj_rew,
+ traj_dones,
+ all_infos,
+ ) = traj
+
+ # B3: concatenate cycles into one [T, E, ...] rollout. Cycles
+ # are contiguous in time so a flat reshape preserves order.
+ T = num_steps
+ obs_t = traj_obs.reshape(T, num_envs, obs_dim) # [T, E, D]
+ acts_t = traj_expert_acts.reshape(T, num_envs) # [T, E]
+ dones_t = traj_dones.reshape(T, num_envs) # [T, E]
+
+ # B3 + B5: sliding-window extraction matching offline.py.
+ # Window (t, e) is valid iff actions [t..t+H-1] all came from
+ # one episode, i.e. dones[t..t+H-2] are all False. An episode
+ # boundary on the *last* action is allowed — the window's
+ # action sequence is still a coherent trajectory.
+ def _window(t_idx):
+ obs_w = obs_t[t_idx] # [E, D]
+ acts_w = jax.lax.dynamic_slice(
+ acts_t, (t_idx, 0), (plan_horizon, num_envs),
+ ) # [H, E]
+ dones_w = jax.lax.dynamic_slice(
+ dones_t, (t_idx, 0), (plan_horizon - 1, num_envs),
+ ) # [H-1, E]
+ valid = ~jnp.any(dones_w, axis=0) # [E]
+ return obs_w, jnp.swapaxes(acts_w, 0, 1), valid
+
+ obs_w, act_w, valid_w = jax.vmap(_window)(
+ jnp.arange(valid_per_rollout),
+ )
+ # obs_w: [W, E, D]
+ # act_w: [W, E, H]
+ # valid_w: [W, E]
+ flat_obs = obs_w.reshape(-1, obs_dim)
+ flat_plans = act_w.reshape(-1, plan_horizon)
+ flat_valid = valid_w.reshape(-1).astype(jnp.float32)
+
+ # B1: write new samples into circular replay buffer
+ write_indices = (
+ buf_write_idx + jnp.arange(samples_per_update)
+ ) % max_buffer_size
+ buf_obs = buf_obs.at[write_indices].set(flat_obs)
+ buf_plans = buf_plans.at[write_indices].set(flat_plans)
+ buf_valid = buf_valid.at[write_indices].set(flat_valid)
+ buf_write_idx = (
+ buf_write_idx + samples_per_update
+ ) % max_buffer_size
+ buf_fill = jnp.minimum(
+ buf_fill + samples_per_update, max_buffer_size,
+ )
+
+ # B1: multi-pass training over the aggregated buffer. Each
+ # pass redraws a fresh sample of size ``samples_per_update``
+ # from the filled portion of the buffer; with the default
+ # ``n_train_passes`` ≈ ⌊|D|/B⌋ one update covers ~|D|
+ # examples once the buffer is full, restoring DAgger's
+ # "train on all of D" requirement without inflating gather
+ # memory beyond a single per-pass batch.
+ def _pass(pass_state, _):
+ state, rng = pass_state
+ rng, sample_rng = jax.random.split(rng)
+ buf_indices = jax.random.randint(
+ sample_rng, (samples_per_update,), 0, buf_fill,
+ )
+ dataset = (
+ buf_obs[buf_indices],
+ buf_plans[buf_indices],
+ buf_valid[buf_indices],
+ )
+
+ def _epoch(epoch_state, _):
+ state, ds, rng = epoch_state
+ rng, perm_rng = jax.random.split(rng)
+ perm = jax.random.permutation(
+ perm_rng, samples_per_update,
+ )
+ shuffled = jax.tree.map(
+ lambda x: jnp.take(x, perm, axis=0), ds,
+ )
+ batches = jax.tree.map(
+ lambda x: x.reshape(
+ num_minibatches, -1, *x.shape[1:],
+ ),
+ shuffled,
+ )
+
+ def _mb(mb_carry, batch):
+ st, rng = mb_carry
+ rng, loss_rng = jax.random.split(rng)
+ obs_b, act_b, val_b = batch
+ adv_b = jnp.ones(act_b.shape[0])
+ st, metrics = grad_step(
+ st,
+ act_b,
+ obs_b,
+ val_b,
+ loss_rng,
+ advantages=adv_b,
+ )
+ return (st, rng), metrics
+
+ (state, rng), metrics = jax.lax.scan(
+ _mb, (state, rng), batches,
+ )
+ return (state, ds, rng), metrics
+
+ (state, _, rng), loss_info = jax.lax.scan(
+ _epoch, (state, dataset, rng), None, update_epochs,
+ )
+ return (state, rng), loss_info
+
+ (state, rng), loss_info = jax.lax.scan(
+ _pass, (state, rng), None, n_train_passes,
+ )
+
+ # --- Metrics ------------------------------------------
+ metric = jax.tree.map(jnp.mean, loss_info)
+ returned = all_infos["returned_episode"]
+ env_metrics = jax.tree.map(
+ lambda x: (x * returned).sum()
+ / (returned.sum() + 1e-8),
+ all_infos,
+ )
+ metric.update(env_metrics)
+ metric["beta"] = beta
+ metric["reward_mean"] = jnp.mean(traj_rew)
+ metric["buffer_fill"] = buf_fill.astype(jnp.float32)
+ metric["valid_frac"] = jnp.mean(flat_valid)
+
+ # --- Periodic validation ------------------------------
+ rng, val_rng = jax.random.split(rng)
+ dummy = jax.tree.map(
+ jnp.zeros_like,
+ {f"val/{k}": v for k, v in env_metrics.items()},
+ )
+ val_metrics = jax.lax.cond(
+ step_idx % val_interval == 0,
+ lambda: _validate(state, val_rng),
+ lambda: dummy,
+ )
+ metric.update(val_metrics)
+
+ # Best-model tracking: update when validation improves.
+ val_ret = val_metrics.get(
+ "val/returned_episode_returns",
+ jnp.array(-jnp.inf),
+ )
+ is_val_step = step_idx % val_interval == 0
+ improved = is_val_step & (val_ret > best_val_return)
+ best_params = jax.tree.map(
+ lambda b, c: jnp.where(improved, c, b),
+ best_params,
+ state.params,
+ )
+ best_val_return = jnp.where(
+ improved, val_ret, best_val_return,
+ )
+ metric["best_val_return"] = best_val_return
+
+ if _wandb_log is not None:
+ jax.debug.callback(_wandb_log, metric, step_idx)
+
+ new_carry = DAggerCarry(
+ train_state=state,
+ env_state=env_state,
+ obs=obs,
+ rng=rng,
+ step_idx=step_idx + 1,
+ ppo_hs=ppo_hs,
+ prev_done=prev_done,
+ buf_obs=buf_obs,
+ buf_plans=buf_plans,
+ buf_valid=buf_valid,
+ buf_write_idx=buf_write_idx,
+ buf_fill=buf_fill,
+ best_params=best_params,
+ best_val_return=best_val_return,
+ )
+ return new_carry, metric
+
+ # --- Outer scan -------------------------------------------
+ rng, run_rng = jax.random.split(rng)
+ runner_init = DAggerCarry(
+ train_state=state,
+ env_state=env_state,
+ obs=obs,
+ rng=run_rng,
+ step_idx=resume_step,
+ ppo_hs=ppo.init_hidden(num_envs),
+ prev_done=jnp.zeros(num_envs, dtype=bool),
+ buf_obs=buf_obs,
+ buf_plans=buf_plans,
+ buf_valid=buf_valid,
+ buf_write_idx=jnp.int32(0),
+ buf_fill=jnp.int32(0),
+ best_params=params,
+ best_val_return=jnp.array(-jnp.inf),
+ )
+ runner_final, metrics = jax.lax.scan(
+ _update_step, runner_init, None, scan_length,
+ )
+ return {
+ "runner_state": runner_final,
+ "metrics": metrics,
+ "best_params": runner_final.best_params,
+ }
+
+ return train
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+
+def run_online(config: dict[str, Any]) -> None:
+ """Configure, compile, and run DAgger online training.
+
+ Args:
+ config: Mixed-case hyperparameter dict from ``defaults.yaml`` / CLI.
+ """
+ config = {k.upper(): v for k, v in config.items()}
+
+ # ONLINE_TOTAL_TIMESTEPS (env frames) is the hardware-portable source of
+ # truth: invariant under num_envs changes, so the same config trains the
+ # same amount of environment experience on any GPU. ONLINE_NUM_UPDATES is
+ # kept as a legacy fallback for configs that prefer the update form.
+ resolve_num_updates(config, "online")
+ # Translate env-frame-denominated hyperparameters (LR_WARMUP_FRAMES,
+ # VAL_INTERVAL_FRAMES, DAGGER_BETA_FINAL, DAGGER_BUFFER_CYCLES) into
+ # their update-step legacy keys. Must run AFTER resolve_num_updates
+ # because DAGGER_BETA_FINAL needs NUM_UPDATES.
+ resolve_scaled_hyperparams(config, "online")
+ print_config_snapshot(config, "online")
+
+ if config.get("USE_WANDB"):
+ init_wandb(
+ config,
+ name=f"{config['ENV_NAME']}-OnlineDiffusion-DAgger-{int(config['ONLINE_TOTAL_TIMESTEPS'] // 1e6)}M",
+ resume_run_id=config.get("RESUME_WANDB_RUN_ID"),
+ )
+
+ rng = jax.random.PRNGKey(config["SEED"])
+ rngs = jax.random.split(rng, config.get("NUM_REPEATS", 1))
+
+ train_fn = jax.jit(jax.vmap(make_train_dagger(config)))
+
+ t0 = time.time()
+ out = train_fn(rngs)
+ elapsed = time.time() - t0
+ print(f"Time: {elapsed:.1f}s SPS: {config['ONLINE_TOTAL_TIMESTEPS'] / elapsed:.0f}")
+
+ if config.get("USE_WANDB") and config.get("SAVE_POLICY"):
+ # Final checkpoint (last iteration params)
+ train_states = out["runner_state"].train_state
+ train_state = jax.tree.map(lambda x: x[0], train_states)
+ path = os.path.join(wandb.run.dir, "policies")
+ with ocp.CheckpointManager(
+ path,
+ options=ocp.CheckpointManagerOptions(max_to_keep=1),
+ ) as mgr:
+ mgr.save(
+ int(config["NUM_UPDATES"]),
+ args=ocp.args.StandardSave(train_state),
+ )
+ print(f"Saved final policy to {path}")
+
+ num_updates = config["NUM_UPDATES"]
+ save_checkpoint_metadata(
+ path,
+ mode="online",
+ update_step=num_updates,
+ total_gradient_steps=num_updates * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"],
+ wandb_run_id=wandb.run.id if wandb.run else None,
+ config=config,
+ )
+
+ artifact = wandb.Artifact(
+ name=f"{config['ENV_NAME']}-policy",
+ type="model",
+ metadata=config,
+ )
+ artifact.add_dir(path)
+ wandb.log_artifact(artifact)
+
+ # Best checkpoint (highest validation return).
+ # Wrap in a dummy TrainState so the Orbax structure matches
+ # the final checkpoint — load_checkpoint expects TrainState.
+ best_params = jax.tree.map(
+ lambda x: x[0], out["best_params"],
+ )
+ tx = optax.chain(
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+ optax.adam(config["LR"], eps=1e-5),
+ )
+ best_state = TrainState.create(
+ apply_fn=lambda *a: None,
+ params=best_params,
+ tx=tx,
+ )
+ best_path = os.path.join(wandb.run.dir, "policies_best")
+ with ocp.CheckpointManager(
+ best_path,
+ options=ocp.CheckpointManagerOptions(max_to_keep=1),
+ ) as mgr:
+ mgr.save(0, args=ocp.args.StandardSave(best_state))
+ print(f"Saved best policy to {best_path}")
+
+ best_artifact = wandb.Artifact(
+ name=f"{config['ENV_NAME']}-policy-best",
+ type="model",
+ metadata=config,
+ )
+ best_artifact.add_dir(best_path)
+ wandb.log_artifact(best_artifact)
+
+ print("Uploaded final + best policy artifacts to wandb")
diff --git a/src/planners/ppo.py b/src/planners/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..fda1d481f92d51290f3549710975038b17969b88
--- /dev/null
+++ b/src/planners/ppo.py
@@ -0,0 +1,188 @@
+"""PPO agent adapter and checkpoint loading utilities."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import orbax.checkpoint as ocp
+
+from Craftax_Baselines.ppo import ActorCritic
+from Craftax_Baselines.ppo_rnn import ActorCriticRNN
+from Craftax_Baselines.ppo_rnd import ActorCriticRND
+
+
+def load_ppo_params(
+ path: str,
+ network: Any,
+ model_type: str,
+ num_envs: int,
+ obs_shape: tuple,
+ layer_size: int = 512,
+) -> Any:
+ """Restore PPO parameters from an Orbax checkpoint.
+
+ Args:
+ path: Path to the Orbax checkpoint directory.
+ network: Instantiated Flax network (used only for structure).
+ model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
+ num_envs: Number of parallel environments (affects RNN init shape).
+ obs_shape: Observation shape tuple.
+ layer_size: Hidden layer size (for RNN hidden state init).
+
+ Returns:
+ Restored parameter pytree.
+ """
+ path = str(Path(path).resolve())
+ rng = jax.random.PRNGKey(0)
+ if model_type == "ppo_rnn":
+ init_x = (jnp.zeros((1, num_envs, *obs_shape)), jnp.zeros((1, num_envs)))
+ abstract = network.init(rng, jnp.zeros((num_envs, layer_size)), init_x)
+ else:
+ abstract = network.init(rng, jnp.zeros((1, *obs_shape)))
+
+ with ocp.CheckpointManager(path) as mgr:
+ step = mgr.latest_step()
+ if step is None:
+ raise FileNotFoundError(f"No checkpoint at {path}")
+ restored = mgr.restore(
+ step,
+ args=ocp.args.PyTreeRestore(item={"params": abstract}, partial_restore=True),
+ )
+ print(f"Loaded {model_type.upper()} checkpoint from '{path}' (step {step})")
+ return restored["params"]
+
+
+def build_ppo_network(model_type: str, num_actions: int, layer_size: int, config: dict) -> Any:
+ """Instantiate the correct PPO architecture.
+
+ Args:
+ model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
+ num_actions: Size of the discrete action space.
+ layer_size: Hidden layer width.
+ config: Training config (forwarded to ``ActorCriticRNN``).
+
+ Returns:
+ Flax module instance.
+ """
+ model_type = model_type.lower()
+ if model_type == "ppo_rnn":
+ return ActorCriticRNN(num_actions, config=config)
+ if model_type == "ppo_rnd":
+ return ActorCriticRND(num_actions, layer_size)
+ return ActorCritic(num_actions, layer_size)
+
+
+def load_ppo_agent(
+ path: str,
+ num_actions: int,
+ obs_dim: int,
+ layer_size: int,
+ model_type: str,
+ config: dict,
+ num_envs: int = 1,
+) -> "PPOAgent":
+ """Build network, load params, and return a :class:`PPOAgent`.
+
+ Args:
+ path: Path to the Orbax checkpoint directory.
+ num_actions: Size of the discrete action space.
+ obs_dim: Observation vector dimensionality.
+ layer_size: Hidden layer width.
+ model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
+ config: Training config dict.
+ num_envs: Number of parallel environments.
+
+ Returns:
+ A fully initialised :class:`PPOAgent`.
+ """
+ net = build_ppo_network(model_type, num_actions, layer_size, config)
+ params = load_ppo_params(path, net, model_type, num_envs, (obs_dim,), layer_size)
+ return PPOAgent(net, params, model_type, layer_size)
+
+
+class PPOAgent:
+ """Uniform interface over PPO-RNN / PPO / PPO-RND for action collection.
+
+ Args:
+ network: Flax actor-critic module.
+ params: Loaded parameter pytree.
+ model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
+ layer_size: Hidden layer width (used for RNN hidden-state shape).
+ """
+
+ def __init__(self, network: Any, params: Any, model_type: str, layer_size: int = 512) -> None:
+ self.network = network
+ self.params = params
+ self.model_type = model_type.lower()
+ self.layer_size = layer_size
+
+ def init_hidden(self, batch_size: int) -> jnp.ndarray | None:
+ """Return a zero hidden state for RNN models, else ``None``."""
+ if self.model_type == "ppo_rnn":
+ return jnp.zeros((batch_size, self.layer_size))
+ return None
+
+ def act(
+ self,
+ obs: jnp.ndarray,
+ done: jnp.ndarray,
+ hidden: jnp.ndarray | None,
+ rng: jax.Array,
+ temperature: float = 1.0,
+ ) -> tuple[jnp.ndarray, jnp.ndarray | None]:
+ """Sample an action.
+
+ Args:
+ obs: Observation array ``[B, obs_dim]``.
+ done: Episode-done flags ``[B]``.
+ hidden: RNN hidden state (``None`` for non-RNN models).
+ rng: PRNG key.
+ temperature: Softmax temperature for sampling.
+
+ Returns:
+ ``(action, new_hidden)`` tuple.
+ """
+ if self.model_type == "ppo_rnn":
+ ac_in = (obs[np.newaxis, :], done[np.newaxis, :])
+ new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in)
+ elif self.model_type == "ppo_rnd":
+ pi, _, _ = self.network.apply(self.params, obs)
+ new_hidden = hidden
+ else:
+ pi, _ = self.network.apply(self.params, obs)
+ new_hidden = hidden
+
+ action = jax.random.categorical(rng, pi.logits / temperature)
+ if self.model_type == "ppo_rnn":
+ action = action.squeeze(0)
+ return action, new_hidden
+
+ def get_pi(
+ self,
+ obs: jnp.ndarray,
+ done: jnp.ndarray | None = None,
+ hidden: jnp.ndarray | None = None,
+ ) -> tuple[Any, jnp.ndarray | None]:
+ """Return the policy distribution (used in DAgger expert labelling).
+
+ Args:
+ obs: Observation array ``[B, obs_dim]``.
+ done: Episode-done flags ``[B]`` (required for RNN models).
+ hidden: RNN hidden state.
+
+ Returns:
+ ``(pi, new_hidden)`` tuple.
+ """
+ if self.model_type == "ppo_rnn":
+ ac_in = (obs[np.newaxis, :], done[np.newaxis, :])
+ new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in)
+ return pi, new_hidden
+ if self.model_type == "ppo_rnd":
+ pi, _, _ = self.network.apply(self.params, obs)
+ return pi, hidden
+ pi, _ = self.network.apply(self.params, obs)
+ return pi, hidden